diff options
| -rw-r--r-- | client/client.go | 25 | ||||
| -rw-r--r-- | client/server.go | 24 | ||||
| -rw-r--r-- | exports/exports.go | 40 | ||||
| -rw-r--r-- | exports/server.h | 47 | ||||
| -rw-r--r-- | exports/servers.go | 44 | ||||
| -rw-r--r-- | internal/server/server.go | 1 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 6 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 34 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/server.py | 14 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 1 |
10 files changed, 188 insertions, 48 deletions
diff --git a/client/client.go b/client/client.go index 2bb0cc2..89db8f5 100644 --- a/client/client.go +++ b/client/client.go @@ -11,6 +11,7 @@ import ( "github.com/eduvpn/eduvpn-common/internal/fsm" "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" "github.com/eduvpn/eduvpn-common/internal/util" "github.com/eduvpn/eduvpn-common/types" @@ -116,6 +117,30 @@ type Client struct { // The Failover monitor for the current VPN connection Failover *failover.DroppedConMon `json:"-"` + + // tokenCB is the token callback + tokenCB func(srv server.Server, tok oauth.Token) +} + +func (c *Client) ForwardTokenUpdate(srv server.Server) { + if c.tokenCB == nil { + log.Logger.Debugf("No token update callback available") + return + } + t := oauth.Token{} + o := srv.OAuth() + if o != nil { + t = o.Token() + } else { + log.Logger.Debugf("OAuth was nil when forwarding token callback") + } + log.Logger.Debugf("Running token callback") + c.tokenCB(srv, t) +} + +// SetTokenUpdater sets the token updater callback +func (c *Client) SetTokenUpdater(updater func(srv server.Server, tok oauth.Token)) { + c.tokenCB = updater } // Register initializes the clientwith the following parameters: diff --git a/client/server.go b/client/server.go index 55a3b9e..5578f04 100644 --- a/client/server.go +++ b/client/server.go @@ -108,6 +108,8 @@ func (c *Client) Cleanup(ct oauth.Token) error { if server.NeedsRelogin(srv) { server.UpdateTokens(srv, ct) } + // update tokens to client + defer c.ForwardTokenUpdate(srv) // Do the /disconnect API call err = server.Disconnect(srv) if err != nil { @@ -260,6 +262,9 @@ func (c *Client) AddInstituteServer(url string) (srv server.Server, err error) { } c.FSM.GoTransitionWithData(StateNoServer, c.Servers) + + // Also forward tokens using the callback + c.ForwardTokenUpdate(srv) return srv, nil } @@ -322,6 +327,9 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (srv server.Server, e return nil, err } c.FSM.GoTransitionWithData(StateNoServer, c.Servers) + + // Also forward tokens using the callback + c.ForwardTokenUpdate(srv) return srv, nil } @@ -369,6 +377,9 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) { } c.FSM.GoTransitionWithData(StateNoServer, c.Servers) + + // Also forward tokens using the callback + c.ForwardTokenUpdate(srv) return srv, nil } @@ -408,6 +419,9 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t oauth.To c.goBackInternal() } + // Also forward tokens using the callback + c.ForwardTokenUpdate(srv) + return cfg, err } @@ -446,6 +460,9 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool, t oauth.T c.goBackInternal() } + // Also forward tokens using the callback + c.ForwardTokenUpdate(srv) + return cfg, err } @@ -483,6 +500,9 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t oauth.Token c.goBackInternal() } + // Also forward tokens using the callback + c.ForwardTokenUpdate(srv) + return cfg, err } @@ -543,7 +563,9 @@ func (c *Client) RenewSession() (err error) { } server.MarkTokensForRenew(srv) - return c.ensureLogin(srv, oauth.Token{}) + c.ForwardTokenUpdate(srv) + err = c.ensureLogin(srv, oauth.Token{}) + return err } // ShouldRenewButton returns true if the renew button should be shown diff --git a/exports/exports.go b/exports/exports.go index d9ec122..a11c0fa 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -3,6 +3,7 @@ package main /* #include <stdlib.h> #include "error.h" +#include "server.h" typedef long long int (*ReadRxBytes)(); typedef struct token { @@ -11,6 +12,8 @@ typedef struct token { unsigned long long int expired; } token; +typedef void (*UpdateToken)(const char* name, server* srv, token* tok); + typedef struct configData { const char* config; const char* config_type; @@ -23,6 +26,12 @@ static long long int get_read_rx_bytes(ReadRxBytes read) { return read(); } + +static void update_token(UpdateToken func, const char* name, server* srv, token* tok) +{ + func(name, srv, tok); +} + static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) { return callback(name, oldstate, newstate, data); @@ -34,6 +43,8 @@ import ( "time" "unsafe" + "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/server" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/go-errors/errors" @@ -145,6 +156,28 @@ func Register( return getError(registerErr) } +//export SetTokenUpdater +func SetTokenUpdater(name *C.char, updater C.UpdateToken) *C.error { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return getError(stateErr) + } + state.SetTokenUpdater(func(srv server.Server, tok oauth.Token) { + b, err := srv.Base() + if err != nil { + log.Logger.Warningf("No server base found for token updating with error: %v", err) + return + } + cName := C.CString(nameStr) + cSrv := getCPtrServer(state, b) + cTok := cToken(tok) + C.update_token(updater, cName, cSrv, cTok) + FreeString(cName) + }) + return nil +} + //export Deregister func Deregister(name *C.char) *C.error { nameStr := C.GoString(name) @@ -284,6 +317,13 @@ func cConfig(config *client.ConfigData) *C.configData { return cConf } +//export FreeTokens +func FreeTokens(tokens *C.token) { + C.free(unsafe.Pointer(tokens.access)) + C.free(unsafe.Pointer(tokens.refresh)) + C.free(unsafe.Pointer(tokens)) +} + //export FreeConfig func FreeConfig(config *C.configData) { C.free(unsafe.Pointer(config.config)) diff --git a/exports/server.h b/exports/server.h new file mode 100644 index 0000000..4bc8a16 --- /dev/null +++ b/exports/server.h @@ -0,0 +1,47 @@ +#ifndef SERVER_H +#define SERVER_H + +// The struct for a single server profile +typedef struct serverProfile { + const char* id; + const char* display_name; + //const char* proto_list; + int default_gateway; +} serverProfile; + +// The struct for all server profiles +typedef struct serverProfiles { + int current; + serverProfile** profiles; + size_t total_profiles; +} serverProfiles; + +// The struct for server locations +typedef struct serverLocations { + const char** locations; + size_t total_locations; +} serverLocations; + +// The struct for a single server +typedef struct server { + const char* identifier; + const char* display_name; + const char* server_type; + const char* country_code; + const char** support_contact; + size_t total_support_contact; + serverLocations* locations; + serverProfiles* profiles; + unsigned long long int expire_time; +} server; + +// The struct for all servers +typedef struct servers { + server** custom_servers; + size_t total_custom; + server** institute_servers; + size_t total_institute; + server* secure_internet_server; +} servers; + +#endif /* GRANDPARENT_H */ diff --git a/exports/servers.go b/exports/servers.go index 662808c..73b8b6c 100644 --- a/exports/servers.go +++ b/exports/servers.go @@ -4,49 +4,7 @@ package main // for free and size_t #include <stdlib.h> #include "error.h" - -// The struct for a single server profile -typedef struct serverProfile { - const char* id; - const char* display_name; - //const char* proto_list; - int default_gateway; -} serverProfile; - -// The struct for all server profiles -typedef struct serverProfiles { - int current; - serverProfile** profiles; - size_t total_profiles; -} serverProfiles; - -// The struct for server locations -typedef struct serverLocations { - const char** locations; - size_t total_locations; -} serverLocations; - -// The struct for a single server -typedef struct server { - const char* identifier; - const char* display_name; - const char* server_type; - const char* country_code; - const char** support_contact; - size_t total_support_contact; - serverLocations* locations; - serverProfiles* profiles; - unsigned long long int expire_time; -} server; - -// The struct for all servers -typedef struct servers { - server** custom_servers; - size_t total_custom; - server** institute_servers; - size_t total_institute; - server* secure_internet_server; -} servers; +#include "server.h" */ import "C" diff --git a/internal/server/server.go b/internal/server/server.go index daaa7a6..95c249e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -177,7 +177,6 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi } t := oauth.Token{} - o := srv.OAuth() if o != nil { t = o.Token() diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 036c201..673d180 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -8,6 +8,7 @@ from eduvpn_common.types import ( cToken, DataError, ReadRxBytes, + UpdateToken, VPNStateChange, ) @@ -68,6 +69,7 @@ def initialize_functions(lib: CDLL) -> None: ], c_void_p lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None lib.FreeConfig.argtypes, lib.FreeConfig.restype = [c_void_p], None + lib.FreeTokens.argtypes, lib.FreeTokens.restype = [c_void_p], None lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ c_void_p ], None @@ -112,6 +114,10 @@ def initialize_functions(lib: CDLL) -> None: VPNStateChange, c_int, ], c_void_p + lib.SetTokenUpdater.argtypes, lib.SetTokenUpdater.restype = [ + c_char_p, + UpdateToken, + ], c_void_p lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ c_char_p, c_char_p, diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 3ca26fe..74ff52d 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,5 +1,5 @@ import threading -from ctypes import cast, c_void_p, c_int, pointer +from ctypes import POINTER, cast, c_void_p, c_int, pointer from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from eduvpn_common.discovery import ( @@ -17,6 +17,7 @@ from eduvpn_common.server import ( encode_tokens, get_config, Server, + get_tokens, get_transition_server, get_servers, ) @@ -24,7 +25,7 @@ from eduvpn_common.state import State, StateType from eduvpn_common.types import ( VPNStateChange, ReadRxBytes, - cToken, + UpdateToken, decode_res, encode_args, get_data_error, @@ -54,6 +55,7 @@ class EduVPN(object): initialize_functions(self.lib) self.event_handler = EventHandler(self.lib) + self.token_callback = None # Callbacks that need to wait for specific events @@ -140,6 +142,16 @@ class EduVPN(object): if register_err: raise register_err + def set_token_updater(self, updater: Callable): + self.token_callback = updater + updater_err = self.go_function( + self.lib.SetTokenUpdater, + token_callback, + ) + + if updater_err: + raise updater_err + def get_disco_servers(self) -> Optional[DiscoServers]: """Get the discovery servers @@ -442,6 +454,11 @@ class EduVPN(object): """ return self.event.run(old_state, new_state, data) + def token_calback(self, srv: Server, tok: Token): + if self.token_callback is None: + return + self.token_callback(srv, tok) + def set_profile(self, profile_id: str) -> None: """Set the profile of the current server @@ -585,6 +602,19 @@ class EduVPN(object): eduvpn_objects: Dict[str, EduVPN] = {} +@UpdateToken +def token_callback(name: bytes, srv, tok): + name_decoded = name.decode() + if name_decoded not in eduvpn_objects: + return 0 + obj = eduvpn_objects[name_decoded] + srv_conv = get_transition_server(obj.lib, srv) + tok_conv = get_tokens(obj.lib, tok) + obj.token_callback( + srv_conv, tok_conv + ) + + @VPNStateChange def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> int: """The internal callback that is passed to the Go library diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py index 068dc61..55eadcd 100644 --- a/wrappers/python/eduvpn_common/server.py +++ b/wrappers/python/eduvpn_common/server.py @@ -392,6 +392,19 @@ def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]: return None +def get_tokens(lib: CDLL, ptr: c_void_p) -> Optional[Token]: + if ptr: + toks = cast(ptr, POINTER(cToken)).contents + access = toks.access.decode("utf-8") + refresh = toks.refresh.decode("utf-8") + expired = toks.expired + lib.FreeTokens(ptr) + return Token( + access, refresh, expired + ) + return None + + def get_config(lib: CDLL, ptr: c_void_p) -> Optional[Config]: """Get the config from the Go library as a C structure and return a Python usable structure @@ -403,7 +416,6 @@ def get_config(lib: CDLL, ptr: c_void_p) -> Optional[Config]: :return: The configuration if there is any :rtype: Optional[Config] """ - # TODO: FREE if ptr: config = cast(ptr, POINTER(cConfig)).contents cfg = config.config.decode("utf-8") diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index 4bc5a85..32a7a00 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -195,6 +195,7 @@ class DataError(Structure): # The type for a Go state change callback VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p) ReadRxBytes = CFUNCTYPE(c_ulonglong) +UpdateToken = CFUNCTYPE(None, c_char_p, c_void_p, c_void_p) def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: |
