summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client/client.go25
-rw-r--r--client/server.go24
-rw-r--r--exports/exports.go40
-rw-r--r--exports/server.h47
-rw-r--r--exports/servers.go44
-rw-r--r--internal/server/server.go1
-rw-r--r--wrappers/python/eduvpn_common/loader.py6
-rw-r--r--wrappers/python/eduvpn_common/main.py34
-rw-r--r--wrappers/python/eduvpn_common/server.py14
-rw-r--r--wrappers/python/eduvpn_common/types.py1
10 files changed, 188 insertions, 48 deletions
diff --git a/client/client.go b/client/client.go
index 2bb0cc2..89db8f5 100644
--- a/client/client.go
+++ b/client/client.go
@@ -11,6 +11,7 @@ import (
"github.com/eduvpn/eduvpn-common/internal/fsm"
"github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server"
"github.com/eduvpn/eduvpn-common/internal/util"
"github.com/eduvpn/eduvpn-common/types"
@@ -116,6 +117,30 @@ type Client struct {
// The Failover monitor for the current VPN connection
Failover *failover.DroppedConMon `json:"-"`
+
+ // tokenCB is the token callback
+ tokenCB func(srv server.Server, tok oauth.Token)
+}
+
+func (c *Client) ForwardTokenUpdate(srv server.Server) {
+ if c.tokenCB == nil {
+ log.Logger.Debugf("No token update callback available")
+ return
+ }
+ t := oauth.Token{}
+ o := srv.OAuth()
+ if o != nil {
+ t = o.Token()
+ } else {
+ log.Logger.Debugf("OAuth was nil when forwarding token callback")
+ }
+ log.Logger.Debugf("Running token callback")
+ c.tokenCB(srv, t)
+}
+
+// SetTokenUpdater sets the token updater callback
+func (c *Client) SetTokenUpdater(updater func(srv server.Server, tok oauth.Token)) {
+ c.tokenCB = updater
}
// Register initializes the clientwith the following parameters:
diff --git a/client/server.go b/client/server.go
index 55a3b9e..5578f04 100644
--- a/client/server.go
+++ b/client/server.go
@@ -108,6 +108,8 @@ func (c *Client) Cleanup(ct oauth.Token) error {
if server.NeedsRelogin(srv) {
server.UpdateTokens(srv, ct)
}
+ // update tokens to client
+ defer c.ForwardTokenUpdate(srv)
// Do the /disconnect API call
err = server.Disconnect(srv)
if err != nil {
@@ -260,6 +262,9 @@ func (c *Client) AddInstituteServer(url string) (srv server.Server, err error) {
}
c.FSM.GoTransitionWithData(StateNoServer, c.Servers)
+
+ // Also forward tokens using the callback
+ c.ForwardTokenUpdate(srv)
return srv, nil
}
@@ -322,6 +327,9 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (srv server.Server, e
return nil, err
}
c.FSM.GoTransitionWithData(StateNoServer, c.Servers)
+
+ // Also forward tokens using the callback
+ c.ForwardTokenUpdate(srv)
return srv, nil
}
@@ -369,6 +377,9 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) {
}
c.FSM.GoTransitionWithData(StateNoServer, c.Servers)
+
+ // Also forward tokens using the callback
+ c.ForwardTokenUpdate(srv)
return srv, nil
}
@@ -408,6 +419,9 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t oauth.To
c.goBackInternal()
}
+ // Also forward tokens using the callback
+ c.ForwardTokenUpdate(srv)
+
return cfg, err
}
@@ -446,6 +460,9 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool, t oauth.T
c.goBackInternal()
}
+ // Also forward tokens using the callback
+ c.ForwardTokenUpdate(srv)
+
return cfg, err
}
@@ -483,6 +500,9 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t oauth.Token
c.goBackInternal()
}
+ // Also forward tokens using the callback
+ c.ForwardTokenUpdate(srv)
+
return cfg, err
}
@@ -543,7 +563,9 @@ func (c *Client) RenewSession() (err error) {
}
server.MarkTokensForRenew(srv)
- return c.ensureLogin(srv, oauth.Token{})
+ c.ForwardTokenUpdate(srv)
+ err = c.ensureLogin(srv, oauth.Token{})
+ return err
}
// ShouldRenewButton returns true if the renew button should be shown
diff --git a/exports/exports.go b/exports/exports.go
index d9ec122..a11c0fa 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -3,6 +3,7 @@ package main
/*
#include <stdlib.h>
#include "error.h"
+#include "server.h"
typedef long long int (*ReadRxBytes)();
typedef struct token {
@@ -11,6 +12,8 @@ typedef struct token {
unsigned long long int expired;
} token;
+typedef void (*UpdateToken)(const char* name, server* srv, token* tok);
+
typedef struct configData {
const char* config;
const char* config_type;
@@ -23,6 +26,12 @@ static long long int get_read_rx_bytes(ReadRxBytes read)
{
return read();
}
+
+static void update_token(UpdateToken func, const char* name, server* srv, token* tok)
+{
+ func(name, srv, tok);
+}
+
static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data)
{
return callback(name, oldstate, newstate, data);
@@ -34,6 +43,8 @@ import (
"time"
"unsafe"
+ "github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/internal/server"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/go-errors/errors"
@@ -145,6 +156,28 @@ func Register(
return getError(registerErr)
}
+//export SetTokenUpdater
+func SetTokenUpdater(name *C.char, updater C.UpdateToken) *C.error {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return getError(stateErr)
+ }
+ state.SetTokenUpdater(func(srv server.Server, tok oauth.Token) {
+ b, err := srv.Base()
+ if err != nil {
+ log.Logger.Warningf("No server base found for token updating with error: %v", err)
+ return
+ }
+ cName := C.CString(nameStr)
+ cSrv := getCPtrServer(state, b)
+ cTok := cToken(tok)
+ C.update_token(updater, cName, cSrv, cTok)
+ FreeString(cName)
+ })
+ return nil
+}
+
//export Deregister
func Deregister(name *C.char) *C.error {
nameStr := C.GoString(name)
@@ -284,6 +317,13 @@ func cConfig(config *client.ConfigData) *C.configData {
return cConf
}
+//export FreeTokens
+func FreeTokens(tokens *C.token) {
+ C.free(unsafe.Pointer(tokens.access))
+ C.free(unsafe.Pointer(tokens.refresh))
+ C.free(unsafe.Pointer(tokens))
+}
+
//export FreeConfig
func FreeConfig(config *C.configData) {
C.free(unsafe.Pointer(config.config))
diff --git a/exports/server.h b/exports/server.h
new file mode 100644
index 0000000..4bc8a16
--- /dev/null
+++ b/exports/server.h
@@ -0,0 +1,47 @@
+#ifndef SERVER_H
+#define SERVER_H
+
+// The struct for a single server profile
+typedef struct serverProfile {
+ const char* id;
+ const char* display_name;
+ //const char* proto_list;
+ int default_gateway;
+} serverProfile;
+
+// The struct for all server profiles
+typedef struct serverProfiles {
+ int current;
+ serverProfile** profiles;
+ size_t total_profiles;
+} serverProfiles;
+
+// The struct for server locations
+typedef struct serverLocations {
+ const char** locations;
+ size_t total_locations;
+} serverLocations;
+
+// The struct for a single server
+typedef struct server {
+ const char* identifier;
+ const char* display_name;
+ const char* server_type;
+ const char* country_code;
+ const char** support_contact;
+ size_t total_support_contact;
+ serverLocations* locations;
+ serverProfiles* profiles;
+ unsigned long long int expire_time;
+} server;
+
+// The struct for all servers
+typedef struct servers {
+ server** custom_servers;
+ size_t total_custom;
+ server** institute_servers;
+ size_t total_institute;
+ server* secure_internet_server;
+} servers;
+
+#endif /* GRANDPARENT_H */
diff --git a/exports/servers.go b/exports/servers.go
index 662808c..73b8b6c 100644
--- a/exports/servers.go
+++ b/exports/servers.go
@@ -4,49 +4,7 @@ package main
// for free and size_t
#include <stdlib.h>
#include "error.h"
-
-// The struct for a single server profile
-typedef struct serverProfile {
- const char* id;
- const char* display_name;
- //const char* proto_list;
- int default_gateway;
-} serverProfile;
-
-// The struct for all server profiles
-typedef struct serverProfiles {
- int current;
- serverProfile** profiles;
- size_t total_profiles;
-} serverProfiles;
-
-// The struct for server locations
-typedef struct serverLocations {
- const char** locations;
- size_t total_locations;
-} serverLocations;
-
-// The struct for a single server
-typedef struct server {
- const char* identifier;
- const char* display_name;
- const char* server_type;
- const char* country_code;
- const char** support_contact;
- size_t total_support_contact;
- serverLocations* locations;
- serverProfiles* profiles;
- unsigned long long int expire_time;
-} server;
-
-// The struct for all servers
-typedef struct servers {
- server** custom_servers;
- size_t total_custom;
- server** institute_servers;
- size_t total_institute;
- server* secure_internet_server;
-} servers;
+#include "server.h"
*/
import "C"
diff --git a/internal/server/server.go b/internal/server/server.go
index daaa7a6..95c249e 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -177,7 +177,6 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi
}
t := oauth.Token{}
-
o := srv.OAuth()
if o != nil {
t = o.Token()
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
index 036c201..673d180 100644
--- a/wrappers/python/eduvpn_common/loader.py
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -8,6 +8,7 @@ from eduvpn_common.types import (
cToken,
DataError,
ReadRxBytes,
+ UpdateToken,
VPNStateChange,
)
@@ -68,6 +69,7 @@ def initialize_functions(lib: CDLL) -> None:
], c_void_p
lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None
lib.FreeConfig.argtypes, lib.FreeConfig.restype = [c_void_p], None
+ lib.FreeTokens.argtypes, lib.FreeTokens.restype = [c_void_p], None
lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
c_void_p
], None
@@ -112,6 +114,10 @@ def initialize_functions(lib: CDLL) -> None:
VPNStateChange,
c_int,
], c_void_p
+ lib.SetTokenUpdater.argtypes, lib.SetTokenUpdater.restype = [
+ c_char_p,
+ UpdateToken,
+ ], c_void_p
lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [
c_char_p,
c_char_p,
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index 3ca26fe..74ff52d 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,5 +1,5 @@
import threading
-from ctypes import cast, c_void_p, c_int, pointer
+from ctypes import POINTER, cast, c_void_p, c_int, pointer
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from eduvpn_common.discovery import (
@@ -17,6 +17,7 @@ from eduvpn_common.server import (
encode_tokens,
get_config,
Server,
+ get_tokens,
get_transition_server,
get_servers,
)
@@ -24,7 +25,7 @@ from eduvpn_common.state import State, StateType
from eduvpn_common.types import (
VPNStateChange,
ReadRxBytes,
- cToken,
+ UpdateToken,
decode_res,
encode_args,
get_data_error,
@@ -54,6 +55,7 @@ class EduVPN(object):
initialize_functions(self.lib)
self.event_handler = EventHandler(self.lib)
+ self.token_callback = None
# Callbacks that need to wait for specific events
@@ -140,6 +142,16 @@ class EduVPN(object):
if register_err:
raise register_err
+ def set_token_updater(self, updater: Callable):
+ self.token_callback = updater
+ updater_err = self.go_function(
+ self.lib.SetTokenUpdater,
+ token_callback,
+ )
+
+ if updater_err:
+ raise updater_err
+
def get_disco_servers(self) -> Optional[DiscoServers]:
"""Get the discovery servers
@@ -442,6 +454,11 @@ class EduVPN(object):
"""
return self.event.run(old_state, new_state, data)
+ def token_calback(self, srv: Server, tok: Token):
+ if self.token_callback is None:
+ return
+ self.token_callback(srv, tok)
+
def set_profile(self, profile_id: str) -> None:
"""Set the profile of the current server
@@ -585,6 +602,19 @@ class EduVPN(object):
eduvpn_objects: Dict[str, EduVPN] = {}
+@UpdateToken
+def token_callback(name: bytes, srv, tok):
+ name_decoded = name.decode()
+ if name_decoded not in eduvpn_objects:
+ return 0
+ obj = eduvpn_objects[name_decoded]
+ srv_conv = get_transition_server(obj.lib, srv)
+ tok_conv = get_tokens(obj.lib, tok)
+ obj.token_callback(
+ srv_conv, tok_conv
+ )
+
+
@VPNStateChange
def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> int:
"""The internal callback that is passed to the Go library
diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py
index 068dc61..55eadcd 100644
--- a/wrappers/python/eduvpn_common/server.py
+++ b/wrappers/python/eduvpn_common/server.py
@@ -392,6 +392,19 @@ def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]:
return None
+def get_tokens(lib: CDLL, ptr: c_void_p) -> Optional[Token]:
+ if ptr:
+ toks = cast(ptr, POINTER(cToken)).contents
+ access = toks.access.decode("utf-8")
+ refresh = toks.refresh.decode("utf-8")
+ expired = toks.expired
+ lib.FreeTokens(ptr)
+ return Token(
+ access, refresh, expired
+ )
+ return None
+
+
def get_config(lib: CDLL, ptr: c_void_p) -> Optional[Config]:
"""Get the config from the Go library as a C structure and return a Python usable structure
@@ -403,7 +416,6 @@ def get_config(lib: CDLL, ptr: c_void_p) -> Optional[Config]:
:return: The configuration if there is any
:rtype: Optional[Config]
"""
- # TODO: FREE
if ptr:
config = cast(ptr, POINTER(cConfig)).contents
cfg = config.config.decode("utf-8")
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
index 4bc5a85..32a7a00 100644
--- a/wrappers/python/eduvpn_common/types.py
+++ b/wrappers/python/eduvpn_common/types.py
@@ -195,6 +195,7 @@ class DataError(Structure):
# The type for a Go state change callback
VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p)
ReadRxBytes = CFUNCTYPE(c_ulonglong)
+UpdateToken = CFUNCTYPE(None, c_char_p, c_void_p, c_void_p)
def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: