diff options
| -rw-r--r-- | client/client.go | 108 | ||||
| -rw-r--r-- | exports/exports.go | 71 | ||||
| -rw-r--r-- | internal/server/server.go | 21 | ||||
| -rw-r--r-- | types/server/server.go | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 56 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 7 |
7 files changed, 228 insertions, 41 deletions
diff --git a/client/client.go b/client/client.go index 813f6dc..28080b0 100644 --- a/client/client.go +++ b/client/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -109,7 +110,59 @@ type Client struct { Debug bool `json:"-"` // The Failover monitor for the current VPN connection - Failover *failover.DroppedConMon + Failover *failover.DroppedConMon `json:"-"` + + // TokenSetter sets the tokens in the client + TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"` + + // TokenGetter gets the tokens from the client + TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"` +} + +func (c *Client) updateTokens(srv server.Server) error { + if c.TokenGetter == nil { + return errors.New("no tokken getter defined") + } + pSrv, err := c.pubCurrentServer(srv) + if err != nil { + return err + } + // shouldn't happen + if pSrv == nil { + return errors.New("public server is nil when getting tokens") + } + tokens := c.TokenGetter(*pSrv) + if tokens == nil { + return errors.New("client returned nil for tokens") + } + + server.UpdateTokens(srv, oauth.Token{ + Access: tokens.Access, + Refresh: tokens.Refresh, + ExpiredTimestamp: time.Unix(tokens.Expires, 0), + }) + + return nil +} + +func (c *Client) forwardTokens(srv server.Server) error { + if c.TokenSetter == nil { + return errors.New("no token setter defined") + } + pSrv, err := c.pubCurrentServer(srv) + if err != nil { + return err + } + if pSrv == nil { + return errors.New("public server is nil when updating tokens") + } + o := srv.OAuth() + if o == nil { + return errors.New("oauth was nil when forwarding tokens") + } + t := o.Token() + c.TokenSetter(*pSrv, t.Public()) + return nil } // New creates a new client with the following parameters: @@ -329,7 +382,18 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) // oauth // TODO: This should be ck.Context() // But needsrelogin needs a rewrite to support this properly + + // first make sure we get the most up to date tokens from the client + err := c.updateTokens(srv) + if err != nil { + log.Logger.Debugf("failed to get tokens from client: %v", err) + } if server.NeedsRelogin(context.Background(), srv) || forceauth { + // mark organizations as expired if the server is a secure internet server + b, berr := srv.Base() + if berr == nil && b.Type == srvtypes.TypeSecureInternet { + c.Discovery.MarkOrganizationsExpired() + } err := c.loginCallback(ck, srv) if err != nil { return err @@ -421,6 +485,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. case srvtypes.TypeSecureInternet: dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { + // We mark the organizations as expired because we got an error + // Note that in the docs it states that it only should happen when the Org ID doesn't exist + // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found + c.Discovery.MarkOrganizationsExpired() return err } srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) @@ -442,7 +510,15 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } // callbacks - return c.callbacks(ck, srv, false) + err = c.callbacks(ck, srv, false) + if err != nil { + return err + } + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after adding: %v", terr) + } + return nil } func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { @@ -529,6 +605,14 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. cfg, err = c.config(ck, srv, pTCP, true) } + // tokens might be updated, forward them + defer func() { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after get config: %v", terr) + } + }() + // still an error, return nil with the error if err != nil { return nil, err @@ -700,15 +784,18 @@ func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { if err != nil { return err } - // TODO: Support cookie context here - // if server.NeedsRelogin(context.Background(), srv) { - // // TODO: ask client for tokens - // } + err = c.updateTokens(srv) + if err != nil { + log.Logger.Debugf("failed to update tokens for disconnect: %v", err) + } err = server.Disconnect(ck.Context(), srv) if err != nil { return err } - // TODO: Set tokens with callback + err = c.forwardTokens(srv) + if err != nil { + log.Logger.Debugf("failed to forward tokens after disconnect: %v", err) + } return nil } @@ -740,6 +827,13 @@ func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { c.FSM.GoTransition(StateLoadingServer) c.FSM.GoTransition(StateChosenServer) } + // update tokens in the end + defer func() { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after renew: %v", terr) + } + }() // TODO: Maybe this can be deleted because we force auth now server.MarkTokensForRenew(srv) // run the callbacks by forcing auth diff --git a/exports/exports.go b/exports/exports.go index f33fedd..6fc7f33 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -9,6 +9,9 @@ typedef long long int (*ReadRxBytes)(); typedef int (*StateCB)(int oldstate, int newstate, void* data); +typedef const char* (*TokenGetter)(const char* server, char* out, size_t len); +typedef void (*TokenSetter)(const char* server, const char* tokens); + static long long int get_read_rx_bytes(ReadRxBytes read) { return read(); @@ -17,10 +20,19 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat { return callback(oldstate, newstate, data); } +static void call_token_getter(TokenGetter getter, const char* server, char* out, size_t len) +{ + getter(server, out, len); +} +static void call_token_setter(TokenSetter setter, const char* server, const char* tokens) +{ + setter(server, tokens); +} */ import "C" import ( + "bytes" "context" "encoding/json" "runtime/cgo" @@ -29,6 +41,7 @@ import ( "github.com/go-errors/errors" "github.com/eduvpn/eduvpn-common/client" + "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/types/cookie" srvtypes "github.com/eduvpn/eduvpn-common/types/server" ) @@ -382,6 +395,64 @@ func getCookie(c C.uintptr_t) (*cookie.Cookie, error) { return v, nil } +//export SetTokenHandler +func SetTokenHandler(getter C.TokenGetter, setter C.TokenSetter) *C.char { + state, stateErr := getVPNState() + if stateErr != nil { + return getCError(stateErr) + } + state.TokenSetter = func(c srvtypes.Current, t srvtypes.Tokens) { + cJSON, err := getReturnData(c) + if err != nil { + log.Logger.Warningf("failed to get current server for setting tokens in exports: %v", err) + return + } + tJSON, err := getReturnData(t) + if err != nil { + log.Logger.Warningf("failed to get tokens for setting tokens in exports: %v", err) + return + } + c1 := C.CString(cJSON) + c2 := C.CString(tJSON) + C.call_token_setter(setter, c1, c2) + FreeString(c1) + FreeString(c2) + } + + state.TokenGetter = func(c srvtypes.Current) *srvtypes.Tokens { + cJSON, err := getReturnData(c) + if err != nil { + log.Logger.Warningf("failed to get current server for getting tokens in exports: %v", err) + return nil + } + c1 := C.CString(cJSON) + // create an output buffer with size 2048 + // In my testing tokens seem to be ~1033 bytes marshalled as JSON + d := make([]byte, 2048) + + C.call_token_getter(getter, c1, (*C.char)(unsafe.Pointer(&d[0])), C.size_t(len(d))) + FreeString(c1) + + // get null pointer index as unmarshalling wants it without + null := bytes.IndexByte(d, 0) + if null < 0 { + log.Logger.Warningf("output buffer is not NULL terminated") + return nil + } + + var gotT srvtypes.Tokens + err = json.Unmarshal(d[:null], &gotT) + if err != nil { + log.Logger.Warningf("failed to get json data for getting tokens in exports: %v", err) + return nil + } + return &gotT + } + + + return nil +} + //export CookieNew func CookieNew() C.uintptr_t { c := cookie.NewWithContext(context.Background()) diff --git a/internal/server/server.go b/internal/server/server.go index e7229c5..c34158a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -108,9 +108,6 @@ type ConfigData struct { // The type of configuration Type string - - // The tokens - Tokens oauth.Token } // Public gets the public data from the types package @@ -120,7 +117,6 @@ func (c *ConfigData) Public(dg bool) srvtypes.Configuration { VPNConfig: c.Config, Protocol: protocol.New(c.Type), DefaultGateway: dg, - Tokens: c.Tokens.Public(), } } @@ -154,13 +150,7 @@ func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPN cfg = wireguard.ConfigAddKey(cfg, key) } - t := oauth.Token{} - o := srv.OAuth() - if o != nil { - t = o.Token() - } - - return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil + return &ConfigData{Config: cfg, Type: proto}, nil } func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) { @@ -178,14 +168,7 @@ func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigD b.StartTime = time.Now() b.EndTime = exp - t := oauth.Token{} - - o := srv.OAuth() - if o != nil { - t = o.Token() - } - - return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil + return &ConfigData{Config: cfg, Type: "openvpn"}, nil } func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) { diff --git a/types/server/server.go b/types/server/server.go index 1578f29..aea496d 100644 --- a/types/server/server.go +++ b/types/server/server.go @@ -130,9 +130,6 @@ type Configuration struct { Protocol protocol.Protocol `json:"protocol"` // DefaultGateway is a boolean that indicates whether or not this configuration should be configured as a default gateway DefaultGateway bool `json:"default_gateway"` - // Tokens is the updated tokens that we get back from the VPN configuration - // They should be used by the client to save them in e.g. the keyring - Tokens Tokens `json:"tokens"` } // Current is the struct that defines the current server diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 961b569..1a172af 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -4,7 +4,7 @@ from collections import defaultdict from ctypes import CDLL, c_char_p, c_int, c_void_p, cdll from eduvpn_common import __version__ -from eduvpn_common.types import BoolError, DataError, ReadRxBytes, VPNStateChange +from eduvpn_common.types import BoolError, DataError, ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange def load_lib() -> CDLL: @@ -88,6 +88,7 @@ def initialize_functions(lib: CDLL) -> None: c_int, ], c_void_p lib.RenewSession.argtypes, lib.RenewSession.restype = [c_int], c_void_p + lib.SetTokenHandler.argtypes, lib.SetTokenHandler.restype = [TokenGetter, TokenSetter], c_void_p lib.Cleanup.argtypes, lib.Cleanup.restype = [c_int], c_void_p lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p], c_void_p lib.CookieNew.argtypes, lib.CookieNew.restype = [], c_int diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index b10e641..5d08ba9 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,8 +1,9 @@ +import ctypes from enum import IntEnum from typing import Any, Callable, Iterator, Optional from eduvpn_common.loader import initialize_functions, load_lib -from eduvpn_common.types import ReadRxBytes, VPNStateChange, decode_res, encode_args +from eduvpn_common.types import ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange, decode_res, encode_args class WrappedError(Exception): @@ -56,6 +57,9 @@ class EduVPN(object): self.version = version self.config_directory = config_directory self.jar = Jar(lambda x: self.go_function(self.lib.CookieCancel, x)) + self.callback = None + self.token_setter = None + self.token_getter = None # Load the library self.lib = load_lib() @@ -88,8 +92,8 @@ class EduVPN(object): This removes the object from internal bookkeeping and saves the configuration """ self.go_function(self.lib.Deregister) - global callback_object - callback_object = None + global global_object + global_object = None def register(self, handler: Optional[Callable] = None, debug: bool = False) -> None: """Register the Go shared library. @@ -99,10 +103,11 @@ class EduVPN(object): :param debug: bool: (Default value = False): Whether or not we want to enable debug logging """ - global callback_object - if callback_object is not None: + global global_object + if global_object is not None: raise Exception("Already registered") - callback_object = handler + self.callback = handler + global_object = self register_err = self.go_function( self.lib.Register, self.name, @@ -244,6 +249,14 @@ class EduVPN(object): if location_err: forwardError(location_err) + def set_token_handler(self, getter: Callable, setter: Callable) -> None: + self.token_setter = setter + self.token_getter = getter + handler_err = self.go_function(self.lib.SetTokenHandler, token_getter, token_setter) + + if handler_err: + forwardError(handler_err) + def cookie_reply(self, cookie: int, data: str) -> None: """Reply with the given cookie and data""" cookie_err = self.go_function(self.lib.CookieReply, cookie, data) @@ -289,8 +302,30 @@ class EduVPN(object): self.jar.cancel() -callback_object: Optional[Callable] = None +global_object: Optional[EduVPN] = None +@TokenSetter +def token_setter(server: ctypes.c_char_p, tokens: ctypes.c_char_p): + global global_object + if global_object is None: + return + if global_object.token_setter is None: + return 0 + global_object.token_setter(server.decode(), tokens.decode()) + +@TokenGetter +def token_getter(server: ctypes.c_char_p, buf: ctypes.c_char_p, size: ctypes.c_size_t): + global global_object + if global_object is None: + return + if global_object.token_getter is None: + return + got = global_object.token_getter(server.decode()) + if got is None: + return + + outbuf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char * size)) + outbuf.contents.value = got.encode("utf-8") @VPNStateChange def state_callback(old_state: int, new_state: int, data: str) -> int: @@ -302,9 +337,12 @@ def state_callback(old_state: int, new_state: int, data: str) -> int: :meta private: """ - if callback_object is None: + global global_object + if global_object is None: + return 0 + if global_object.callback is None: return 0 - handled = callback_object(old_state, new_state, data.decode("utf-8")) + handled = global_object.callback(old_state, new_state, data.decode("utf-8")) if handled: return 1 return 0 diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index 1eba468..f83e710 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -1,10 +1,13 @@ from ctypes import ( CDLL, CFUNCTYPE, + POINTER, Structure, + c_char, c_char_p, c_int, c_ulonglong, + c_size_t, c_void_p, cast, ) @@ -32,8 +35,8 @@ class BoolError(Structure): # The type for a Go state change callback VPNStateChange = CFUNCTYPE(c_int, c_int, c_int, c_char_p) ReadRxBytes = CFUNCTYPE(c_ulonglong) -UpdateToken = CFUNCTYPE(None, c_char_p, c_void_p, c_void_p) - +TokenGetter = CFUNCTYPE(c_void_p, c_char_p, POINTER(c_char), c_size_t) +TokenSetter = CFUNCTYPE(c_void_p, c_char_p, c_char_p) def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: """Encode the arguments ready to be used by the Go library |
