diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-04-13 15:25:42 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | f6f4396a81ce662eb5fb50c4f9fef92ffaadb333 (patch) | |
| tree | 3e6b1c3ecc049c1aa377c8ca813f00b96437d2f3 | |
| parent | 5610dca6c5e391ee62874c4d6cb25072d9c3c1d9 (diff) | |
All: Implement a token handler
This implements a token handler for OAuth tokens. Clients can use the SetTokenHandler
function in exports to set a token handler.
It needs two arguments, a getter and a setter. The getter is a callback with three arguments:
- The server to get the tokens for, in types.server.current as JSON
- The output buffer
- The output buffer maximum length
The tokens should be written to the output buffer with maximum
length. The type should be types.server.Tokens and be marshalled as
JSON. If no tokens are available, leave the output buffer intact
The token setter is a callback with two arguments:
- The server for which to set the tokens for, in types.server.Current as JSON
- The tokens, defined in types.server.Tokens as JSON
Breaking changes:
- No more tokens as arguments, was already deprecated in previous commits
- Tokens are no longer returned in types.server.Configuration
| -rw-r--r-- | client/client.go | 108 | ||||
| -rw-r--r-- | exports/exports.go | 71 | ||||
| -rw-r--r-- | internal/server/server.go | 21 | ||||
| -rw-r--r-- | types/server/server.go | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 56 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 7 |
7 files changed, 228 insertions, 41 deletions
diff --git a/client/client.go b/client/client.go index 813f6dc..28080b0 100644 --- a/client/client.go +++ b/client/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -109,7 +110,59 @@ type Client struct { Debug bool `json:"-"` // The Failover monitor for the current VPN connection - Failover *failover.DroppedConMon + Failover *failover.DroppedConMon `json:"-"` + + // TokenSetter sets the tokens in the client + TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"` + + // TokenGetter gets the tokens from the client + TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"` +} + +func (c *Client) updateTokens(srv server.Server) error { + if c.TokenGetter == nil { + return errors.New("no tokken getter defined") + } + pSrv, err := c.pubCurrentServer(srv) + if err != nil { + return err + } + // shouldn't happen + if pSrv == nil { + return errors.New("public server is nil when getting tokens") + } + tokens := c.TokenGetter(*pSrv) + if tokens == nil { + return errors.New("client returned nil for tokens") + } + + server.UpdateTokens(srv, oauth.Token{ + Access: tokens.Access, + Refresh: tokens.Refresh, + ExpiredTimestamp: time.Unix(tokens.Expires, 0), + }) + + return nil +} + +func (c *Client) forwardTokens(srv server.Server) error { + if c.TokenSetter == nil { + return errors.New("no token setter defined") + } + pSrv, err := c.pubCurrentServer(srv) + if err != nil { + return err + } + if pSrv == nil { + return errors.New("public server is nil when updating tokens") + } + o := srv.OAuth() + if o == nil { + return errors.New("oauth was nil when forwarding tokens") + } + t := o.Token() + c.TokenSetter(*pSrv, t.Public()) + return nil } // New creates a new client with the following parameters: @@ -329,7 +382,18 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) // oauth // TODO: This should be ck.Context() // But needsrelogin needs a rewrite to support this properly + + // first make sure we get the most up to date tokens from the client + err := c.updateTokens(srv) + if err != nil { + log.Logger.Debugf("failed to get tokens from client: %v", err) + } if server.NeedsRelogin(context.Background(), srv) || forceauth { + // mark organizations as expired if the server is a secure internet server + b, berr := srv.Base() + if berr == nil && b.Type == srvtypes.TypeSecureInternet { + c.Discovery.MarkOrganizationsExpired() + } err := c.loginCallback(ck, srv) if err != nil { return err @@ -421,6 +485,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. case srvtypes.TypeSecureInternet: dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { + // We mark the organizations as expired because we got an error + // Note that in the docs it states that it only should happen when the Org ID doesn't exist + // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found + c.Discovery.MarkOrganizationsExpired() return err } srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) @@ -442,7 +510,15 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } // callbacks - return c.callbacks(ck, srv, false) + err = c.callbacks(ck, srv, false) + if err != nil { + return err + } + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after adding: %v", terr) + } + return nil } func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { @@ -529,6 +605,14 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. cfg, err = c.config(ck, srv, pTCP, true) } + // tokens might be updated, forward them + defer func() { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after get config: %v", terr) + } + }() + // still an error, return nil with the error if err != nil { return nil, err @@ -700,15 +784,18 @@ func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { if err != nil { return err } - // TODO: Support cookie context here - // if server.NeedsRelogin(context.Background(), srv) { - // // TODO: ask client for tokens - // } + err = c.updateTokens(srv) + if err != nil { + log.Logger.Debugf("failed to update tokens for disconnect: %v", err) + } err = server.Disconnect(ck.Context(), srv) if err != nil { return err } - // TODO: Set tokens with callback + err = c.forwardTokens(srv) + if err != nil { + log.Logger.Debugf("failed to forward tokens after disconnect: %v", err) + } return nil } @@ -740,6 +827,13 @@ func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { c.FSM.GoTransition(StateLoadingServer) c.FSM.GoTransition(StateChosenServer) } + // update tokens in the end + defer func() { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after renew: %v", terr) + } + }() // TODO: Maybe this can be deleted because we force auth now server.MarkTokensForRenew(srv) // run the callbacks by forcing auth diff --git a/exports/exports.go b/exports/exports.go index f33fedd..6fc7f33 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -9,6 +9,9 @@ typedef long long int (*ReadRxBytes)(); typedef int (*StateCB)(int oldstate, int newstate, void* data); +typedef const char* (*TokenGetter)(const char* server, char* out, size_t len); +typedef void (*TokenSetter)(const char* server, const char* tokens); + static long long int get_read_rx_bytes(ReadRxBytes read) { return read(); @@ -17,10 +20,19 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat { return callback(oldstate, newstate, data); } +static void call_token_getter(TokenGetter getter, const char* server, char* out, size_t len) +{ + getter(server, out, len); +} +static void call_token_setter(TokenSetter setter, const char* server, const char* tokens) +{ + setter(server, tokens); +} */ import "C" import ( + "bytes" "context" "encoding/json" "runtime/cgo" @@ -29,6 +41,7 @@ import ( "github.com/go-errors/errors" "github.com/eduvpn/eduvpn-common/client" + "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/types/cookie" srvtypes "github.com/eduvpn/eduvpn-common/types/server" ) @@ -382,6 +395,64 @@ func getCookie(c C.uintptr_t) (*cookie.Cookie, error) { return v, nil } +//export SetTokenHandler +func SetTokenHandler(getter C.TokenGetter, setter C.TokenSetter) *C.char { + state, stateErr := getVPNState() + if stateErr != nil { + return getCError(stateErr) + } + state.TokenSetter = func(c srvtypes.Current, t srvtypes.Tokens) { + cJSON, err := getReturnData(c) + if err != nil { + log.Logger.Warningf("failed to get current server for setting tokens in exports: %v", err) + return + } + tJSON, err := getReturnData(t) + if err != nil { + log.Logger.Warningf("failed to get tokens for setting tokens in exports: %v", err) + return + } + c1 := C.CString(cJSON) + c2 := C.CString(tJSON) + C.call_token_setter(setter, c1, c2) + FreeString(c1) + FreeString(c2) + } + + state.TokenGetter = func(c srvtypes.Current) *srvtypes.Tokens { + cJSON, err := getReturnData(c) + if err != nil { + log.Logger.Warningf("failed to get current server for getting tokens in exports: %v", err) + return nil + } + c1 := C.CString(cJSON) + // create an output buffer with size 2048 + // In my testing tokens seem to be ~1033 bytes marshalled as JSON + d := make([]byte, 2048) + + C.call_token_getter(getter, c1, (*C.char)(unsafe.Pointer(&d[0])), C.size_t(len(d))) + FreeString(c1) + + // get null pointer index as unmarshalling wants it without + null := bytes.IndexByte(d, 0) + if null < 0 { + log.Logger.Warningf("output buffer is not NULL terminated") + return nil + } + + var gotT srvtypes.Tokens + err = json.Unmarshal(d[:null], &gotT) + if err != nil { + log.Logger.Warningf("failed to get json data for getting tokens in exports: %v", err) + return nil + } + return &gotT + } + + + return nil +} + //export CookieNew func CookieNew() C.uintptr_t { c := cookie.NewWithContext(context.Background()) diff --git a/internal/server/server.go b/internal/server/server.go index e7229c5..c34158a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -108,9 +108,6 @@ type ConfigData struct { // The type of configuration Type string - - // The tokens - Tokens oauth.Token } // Public gets the public data from the types package @@ -120,7 +117,6 @@ func (c *ConfigData) Public(dg bool) srvtypes.Configuration { VPNConfig: c.Config, Protocol: protocol.New(c.Type), DefaultGateway: dg, - Tokens: c.Tokens.Public(), } } @@ -154,13 +150,7 @@ func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPN cfg = wireguard.ConfigAddKey(cfg, key) } - t := oauth.Token{} - o := srv.OAuth() - if o != nil { - t = o.Token() - } - - return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil + return &ConfigData{Config: cfg, Type: proto}, nil } func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) { @@ -178,14 +168,7 @@ func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigD b.StartTime = time.Now() b.EndTime = exp - t := oauth.Token{} - - o := srv.OAuth() - if o != nil { - t = o.Token() - } - - return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil + return &ConfigData{Config: cfg, Type: "openvpn"}, nil } func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) { diff --git a/types/server/server.go b/types/server/server.go index 1578f29..aea496d 100644 --- a/types/server/server.go +++ b/types/server/server.go @@ -130,9 +130,6 @@ type Configuration struct { Protocol protocol.Protocol `json:"protocol"` // DefaultGateway is a boolean that indicates whether or not this configuration should be configured as a default gateway DefaultGateway bool `json:"default_gateway"` - // Tokens is the updated tokens that we get back from the VPN configuration - // They should be used by the client to save them in e.g. the keyring - Tokens Tokens `json:"tokens"` } // Current is the struct that defines the current server diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 961b569..1a172af 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -4,7 +4,7 @@ from collections import defaultdict from ctypes import CDLL, c_char_p, c_int, c_void_p, cdll from eduvpn_common import __version__ -from eduvpn_common.types import BoolError, DataError, ReadRxBytes, VPNStateChange +from eduvpn_common.types import BoolError, DataError, ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange def load_lib() -> CDLL: @@ -88,6 +88,7 @@ def initialize_functions(lib: CDLL) -> None: c_int, ], c_void_p lib.RenewSession.argtypes, lib.RenewSession.restype = [c_int], c_void_p + lib.SetTokenHandler.argtypes, lib.SetTokenHandler.restype = [TokenGetter, TokenSetter], c_void_p lib.Cleanup.argtypes, lib.Cleanup.restype = [c_int], c_void_p lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p], c_void_p lib.CookieNew.argtypes, lib.CookieNew.restype = [], c_int diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index b10e641..5d08ba9 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,8 +1,9 @@ +import ctypes from enum import IntEnum from typing import Any, Callable, Iterator, Optional from eduvpn_common.loader import initialize_functions, load_lib -from eduvpn_common.types import ReadRxBytes, VPNStateChange, decode_res, encode_args +from eduvpn_common.types import ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange, decode_res, encode_args class WrappedError(Exception): @@ -56,6 +57,9 @@ class EduVPN(object): self.version = version self.config_directory = config_directory self.jar = Jar(lambda x: self.go_function(self.lib.CookieCancel, x)) + self.callback = None + self.token_setter = None + self.token_getter = None # Load the library self.lib = load_lib() @@ -88,8 +92,8 @@ class EduVPN(object): This removes the object from internal bookkeeping and saves the configuration """ self.go_function(self.lib.Deregister) - global callback_object - callback_object = None + global global_object + global_object = None def register(self, handler: Optional[Callable] = None, debug: bool = False) -> None: """Register the Go shared library. @@ -99,10 +103,11 @@ class EduVPN(object): :param debug: bool: (Default value = False): Whether or not we want to enable debug logging """ - global callback_object - if callback_object is not None: + global global_object + if global_object is not None: raise Exception("Already registered") - callback_object = handler + self.callback = handler + global_object = self register_err = self.go_function( self.lib.Register, self.name, @@ -244,6 +249,14 @@ class EduVPN(object): if location_err: forwardError(location_err) + def set_token_handler(self, getter: Callable, setter: Callable) -> None: + self.token_setter = setter + self.token_getter = getter + handler_err = self.go_function(self.lib.SetTokenHandler, token_getter, token_setter) + + if handler_err: + forwardError(handler_err) + def cookie_reply(self, cookie: int, data: str) -> None: """Reply with the given cookie and data""" cookie_err = self.go_function(self.lib.CookieReply, cookie, data) @@ -289,8 +302,30 @@ class EduVPN(object): self.jar.cancel() -callback_object: Optional[Callable] = None +global_object: Optional[EduVPN] = None +@TokenSetter +def token_setter(server: ctypes.c_char_p, tokens: ctypes.c_char_p): + global global_object + if global_object is None: + return + if global_object.token_setter is None: + return 0 + global_object.token_setter(server.decode(), tokens.decode()) + +@TokenGetter +def token_getter(server: ctypes.c_char_p, buf: ctypes.c_char_p, size: ctypes.c_size_t): + global global_object + if global_object is None: + return + if global_object.token_getter is None: + return + got = global_object.token_getter(server.decode()) + if got is None: + return + + outbuf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char * size)) + outbuf.contents.value = got.encode("utf-8") @VPNStateChange def state_callback(old_state: int, new_state: int, data: str) -> int: @@ -302,9 +337,12 @@ def state_callback(old_state: int, new_state: int, data: str) -> int: :meta private: """ - if callback_object is None: + global global_object + if global_object is None: + return 0 + if global_object.callback is None: return 0 - handled = callback_object(old_state, new_state, data.decode("utf-8")) + handled = global_object.callback(old_state, new_state, data.decode("utf-8")) if handled: return 1 return 0 diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index 1eba468..f83e710 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -1,10 +1,13 @@ from ctypes import ( CDLL, CFUNCTYPE, + POINTER, Structure, + c_char, c_char_p, c_int, c_ulonglong, + c_size_t, c_void_p, cast, ) @@ -32,8 +35,8 @@ class BoolError(Structure): # The type for a Go state change callback VPNStateChange = CFUNCTYPE(c_int, c_int, c_int, c_char_p) ReadRxBytes = CFUNCTYPE(c_ulonglong) -UpdateToken = CFUNCTYPE(None, c_char_p, c_void_p, c_void_p) - +TokenGetter = CFUNCTYPE(c_void_p, c_char_p, POINTER(c_char), c_size_t) +TokenSetter = CFUNCTYPE(c_void_p, c_char_p, c_char_p) def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: """Encode the arguments ready to be used by the Go library |
