diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-12-20 15:43:55 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-12-21 18:28:50 +0100 |
| commit | 12838c19514459974cf0a71c42f1248b1cb9419c (patch) | |
| tree | a4254d20bb7b0ef49a2fa6c12753eb4c5acb64d1 | |
| parent | 6981666c6d8f639a1ff9c09a3bc08769e19928af (diff) | |
Exports + OAuth + Server: Forward tokens to getting a config
| -rw-r--r-- | client/server.go | 107 | ||||
| -rw-r--r-- | exports/exports.go | 86 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 24 | ||||
| -rw-r--r-- | internal/oauth/token.go | 74 | ||||
| -rw-r--r-- | internal/server/server.go | 53 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 12 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 43 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/server.py | 63 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 54 |
9 files changed, 359 insertions, 157 deletions
diff --git a/client/server.go b/client/server.go index 6802a47..0bb37a8 100644 --- a/client/server.go +++ b/client/server.go @@ -9,12 +9,14 @@ import ( "github.com/go-errors/errors" ) +type ConfigData = server.ConfigData + // getConfigAuth gets a config with authorization and authentication. // It also asks for a profile if no valid profile is found. -func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, string, error) { - err := c.ensureLogin(srv) +func (c *Client) getConfigAuth(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) { + err := c.ensureLogin(srv, t) if err != nil { - return "", "", err + return nil, err } // TODO(jwijenbergh): Should we check if it returns false? @@ -22,13 +24,13 @@ func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, strin ok, err := server.HasValidProfile(srv, c.SupportsWireguard) if err != nil { - return "", "", err + return nil, err } // No valid profile, ask for one if !ok { if err = c.askProfile(srv); err != nil { - return "", "", err + return nil, err } } @@ -38,29 +40,28 @@ func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, strin // retryConfigAuth retries the getConfigAuth function if the tokens are invalid. // If OAuth is cancelled, it makes sure that we only forward the error as additional info. -func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool) (string, string, error) { - cfg, cfgType, err := c.getConfigAuth(srv, preferTCP) +func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) { + cfg, err := c.getConfigAuth(srv, preferTCP, t) if err == nil { - return cfg, cfgType, nil - } - if err != nil { - // Only retry if the error is that the tokens are invalid - tErr := &oauth.TokensInvalidError{} - if errors.As(err, &tErr) { - cfg, cfgType, err = c.getConfigAuth(srv, preferTCP) - if err == nil { - return cfg, cfgType, nil - } + return cfg, nil + } + // Only retry if the error is that the tokens are invalid + tErr := &oauth.TokensInvalidError{} + if errors.As(err, &tErr) { + // TODO: Is passing empty tokens correct here? + cfg, err = c.getConfigAuth(srv, preferTCP, oauth.Token{}) + if err == nil { + return cfg, nil } - c.goBackInternal() } - return "", "", err + c.goBackInternal() + return nil, err } // getConfig gets an OpenVPN/WireGuard configuration by contacting the server, moving the FSM towards the DISCONNECTED state and then saving the local configuration file. -func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, error) { +func (c *Client) getConfig(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) { if c.InFSMState(StateDeregistered) { - return "", "", errors.Errorf("getConfig attempt in '%v'", StateDeregistered) + return nil, errors.Errorf("getConfig attempt in '%v'", StateDeregistered) } // Refresh the server endpoints @@ -70,14 +71,14 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e c.Logger.Warningf("failed to refresh server endpoints: %v", err) } - cfg, cfgType, err := c.retryConfigAuth(srv, preferTCP) + cfg, err := c.retryConfigAuth(srv, preferTCP, t) if err != nil { - return "", "", err + return nil, err } srv1, err := c.Servers.GetCurrentServer() if err != nil { - return "", "", err + return nil, err } // Signal the server display info @@ -91,7 +92,7 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e err.Error(), err.(*errors.Error).ErrorStack()) } - return cfg, cfgType, nil + return cfg, nil } // SetSecureLocation sets the location for the current secure location server. countryCode is the secure location to be chosen. @@ -227,7 +228,7 @@ func (c *Client) AddInstituteServer(url string) (srv server.Server, err error) { c.FSM.GoTransition(StateChosenServer) // Authorize it - if err = c.ensureLogin(srv); err != nil { + if err = c.ensureLogin(srv, oauth.Token{}); err != nil { // Removing is best effort _ = c.RemoveInstituteAccess(url) return nil, err @@ -285,7 +286,7 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (srv server.Server, e c.FSM.GoTransition(StateChosenServer) // Authorize it - if err = c.ensureLogin(srv); err != nil { + if err = c.ensureLogin(srv, oauth.Token{}); err != nil { // Removing is best effort _ = c.RemoveSecureInternet() return nil, err @@ -331,7 +332,7 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) { c.FSM.GoTransition(StateChosenServer) // Authorize it - if err = c.ensureLogin(srv); err != nil { + if err = c.ensureLogin(srv, oauth.Token{}); err != nil { // removing is best effort _ = c.RemoveCustomServer(url) return nil, err @@ -344,7 +345,7 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) { // GetConfigInstituteAccess gets a configuration for an Institute Access Server. // It ensures that the Institute Access Server exists by creating or using an existing one with the url. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg string, cfgType string, err error) { +func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) { defer func() { if err != nil { c.logError(err) @@ -353,7 +354,7 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg strin // Not supported with Let's Connect! if c.isLetsConnect() { - return "", "", errors.Errorf("discovery with Let's Connect is not supported") + return nil, errors.Errorf("discovery with Let's Connect is not supported") } c.FSM.GoTransition(StateLoadingServer) @@ -362,28 +363,28 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg strin var srv *server.InstituteAccessServer if srv, err = c.Servers.GetInstituteAccess(url); err != nil { c.goBackInternal() - return "", "", err + return nil, err } // Set the server as the current if err = c.Servers.SetInstituteAccess(srv); err != nil { - return "", "", err + return nil, err } // The server has now been chosen c.FSM.GoTransition(StateChosenServer) - if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil { + if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { c.goBackInternal() } - return cfg, cfgType, err + return cfg, err } // GetConfigSecureInternet gets a configuration for a Secure Internet Server. // It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg string, cfgType string, err error) { +func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) { defer func() { if err != nil { c.logError(err) @@ -392,7 +393,7 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg stri // Not supported with Let's Connect! if c.isLetsConnect() { - return "", "", errors.Errorf("discovery with Let's Connect is not supported") + return nil, errors.Errorf("discovery with Let's Connect is not supported") } c.FSM.GoTransition(StateLoadingServer) @@ -401,27 +402,27 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg stri var srv *server.SecureInternetHomeServer if srv, err = c.Servers.GetSecureInternetHomeServer(); err != nil { c.goBackInternal() - return "", "", err + return nil, err } // Set the server as the current if err = c.Servers.SetSecureInternet(srv); err != nil { - return "", "", err + return nil, err } c.FSM.GoTransition(StateChosenServer) - if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil { + if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { c.goBackInternal() } - return cfg, cfgType, err + return cfg, err } // GetConfigCustomServer gets a configuration for a Custom Server. // It ensures that the Custom Server exists by creating or using an existing one with the url. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, cfgType string, err error) { +func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) { defer func() { if err != nil { c.logError(err) @@ -429,7 +430,7 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, }() if url, err = util.EnsureValidURL(url); err != nil { - return "", "", err + return nil, err } c.FSM.GoTransition(StateLoadingServer) @@ -438,21 +439,21 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, var srv *server.InstituteAccessServer if srv, err = c.Servers.GetCustomServer(url); err != nil { c.goBackInternal() - return "", "", err + return nil, err } // Set the server as the current if err = c.Servers.SetCustomServer(srv); err != nil { - return "", "", err + return nil, err } c.FSM.GoTransition(StateChosenServer) - if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil { + if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { c.goBackInternal() } - return cfg, cfgType, err + return cfg, err } // askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state. @@ -512,7 +513,7 @@ func (c *Client) RenewSession() (err error) { } server.MarkTokensForRenew(srv) - return c.ensureLogin(srv) + return c.ensureLogin(srv, oauth.Token{}) } // ShouldRenewButton returns true if the renew button should be shown @@ -536,7 +537,7 @@ func (c *Client) ShouldRenewButton() bool { // ensureLogin logs the user back in if needed. // It runs the FSM transitions to ask for user input. -func (c *Client) ensureLogin(srv server.Server) (err error) { +func (c *Client) ensureLogin(srv server.Server, ct oauth.Token) (err error) { // Relogin with oauth // This moves the state to authorized if !server.NeedsRelogin(srv) { @@ -545,6 +546,16 @@ func (c *Client) ensureLogin(srv server.Server) (err error) { return nil } + // Try again but update the tokens using the client provided tokens + server.UpdateTokens(srv, ct) + if !server.NeedsRelogin(srv) { + // OAuth was valid, ensure we are in the authorized state + c.FSM.GoTransition(StateAuthorized) + return nil + } + + // Tokens are not valid or the client gave an error when updating tokens + // Otherwise, do the OAuth exchange var url string if url, err = server.OAuthURL(srv, c.Name); err != nil { return err diff --git a/exports/exports.go b/exports/exports.go index e374661..904fbec 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -5,6 +5,18 @@ package main #include "error.h" typedef long long int (*ReadRxBytes)(); +typedef struct token { + const char* access; + const char* refresh; + unsigned long long int expired; +} token; + +typedef struct configData { + const char* config; + const char* config_type; + token* tokens; +} configData; + typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data); static long long int get_read_rx_bytes(ReadRxBytes read) @@ -20,8 +32,10 @@ import "C" import ( "unsafe" + "time" "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/go-errors/errors" "github.com/eduvpn/eduvpn-common/client" @@ -249,20 +263,56 @@ func RemoveCustomServer(name *C.char, url *C.char) *C.error { return getError(removeErr) } +func cToken(t oauth.Token) *C.token { + cTok := (*C.token)(C.malloc(C.size_t(unsafe.Sizeof(C.token{})))) + cTok.access = C.CString(t.Access) + cTok.refresh = C.CString(t.Refresh) + cTok.expired = C.ulonglong(t.ExpiredTimestamp.Unix()) + return cTok +} + +func cConfig(config *client.ConfigData) *C.configData { + // No config so return nil pointer + if config == nil { + return nil + } + cConf := (*C.configData)(C.malloc(C.size_t(unsafe.Sizeof(C.configData{})))) + cConf.config = C.CString(config.Config) + cConf.config_type = C.CString(config.Type) + cConf.tokens = cToken(config.Tokens) + return cConf +} + +//export FreeConfig +func FreeConfig(config *C.configData) { + C.free(unsafe.Pointer(config.config)) + C.free(unsafe.Pointer(config.config_type)) + C.free(unsafe.Pointer(config.tokens.access)) + C.free(unsafe.Pointer(config.tokens.refresh)) + C.free(unsafe.Pointer(config.tokens)) + C.free(unsafe.Pointer(config)) +} + //export GetConfigSecureInternet func GetConfigSecureInternet( name *C.char, orgID *C.char, preferTCP C.int, -) (*C.char, *C.char, *C.error) { + prevTokens C.token, +) (*C.configData, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, nil, getError(stateErr) + return nil, getError(stateErr) } preferTCPBool := preferTCP == 1 - config, configType, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool) - return C.CString(config), C.CString(configType), getError(configErr) + t := oauth.Token{ + Access: C.GoString(prevTokens.access), + Refresh: C.GoString(prevTokens.refresh), + ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + } + cfg, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool, t) + return cConfig(cfg), getError(configErr) } //export GetConfigInstituteAccess @@ -270,15 +320,21 @@ func GetConfigInstituteAccess( name *C.char, url *C.char, preferTCP C.int, -) (*C.char, *C.char, *C.error) { + prevTokens C.token, +) (*C.configData, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, nil, getError(stateErr) + return nil, getError(stateErr) } preferTCPBool := preferTCP == 1 - config, configType, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool) - return C.CString(config), C.CString(configType), getError(configErr) + t := oauth.Token{ + Access: C.GoString(prevTokens.access), + Refresh: C.GoString(prevTokens.refresh), + ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + } + cfg, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool, t) + return cConfig(cfg), getError(configErr) } //export GetConfigCustomServer @@ -286,15 +342,21 @@ func GetConfigCustomServer( name *C.char, url *C.char, preferTCP C.int, -) (*C.char, *C.char, *C.error) { + prevTokens C.token, +) (*C.configData, *C.error) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { - return nil, nil, getError(stateErr) + return nil, getError(stateErr) } preferTCPBool := preferTCP == 1 - config, configType, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool) - return C.CString(config), C.CString(configType), getError(configErr) + t := oauth.Token{ + Access: C.GoString(prevTokens.access), + Refresh: C.GoString(prevTokens.refresh), + ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + } + cfg, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool, t) + return cConfig(cfg), getError(configErr) } //export SetProfileID diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index ce86337..6d21c82 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -196,10 +196,19 @@ func (oauth *OAuth) SetTokenExpired() { // SetTokenRenew sets the tokens for renewal by completely clearing the structure. func (oauth *OAuth) SetTokenRenew() { if oauth.token != nil { - oauth.token.Clear() + oauth.token.Update(Token{}) } } +func (oauth *OAuth) Token() Token { + t := Token{} + if oauth.token != nil { + t = oauth.token.Get() + } + + return t +} + // tokensWithAuthCode gets the access and refresh tokens using the authorization code // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 @@ -239,10 +248,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { return errors.New("No token response after authorization code") } - oauth.token.Update(*tr, now) + oauth.token.UpdateResponse(*tr, now) return nil } +func (oauth *OAuth) UpdateTokens(t Token) { + if oauth.token == nil { + oauth.token = &tokenLock{t: &tokenRefresher{Refresher: oauth.refreshResponse}} + } + oauth.token.Update(t) +} + // refreshResponse gets the refresh token response with a refresh token // This response contains the access and refresh tokens, together with a timestamp // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 @@ -420,8 +436,8 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s return "", errors.WrapPrefix(err, "genState error", 0) } - // Fill the oauth tokens - oauth.token = &tokenLock{t: &token{Refresher: oauth.refreshResponse}} + // Re-initialize the token structure + oauth.UpdateTokens(Token{}) // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauth.session = exchangeSession{ diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 855677c..4ed8f43 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -23,17 +23,23 @@ type TokenResponse struct { Expires int64 `json:"expires_in"` } -// token is a structure that contains our access and refresh tokens and a timestamp when they expire. -type token struct { - // Access is the access token returned by the server - access string +// The public type that can be passed to an update function +// It contains our access and refresh tokens with a timestamp +type Token struct { + // Access is the Access token returned by the server + Access string - // Refresh token is the refresh token returned by the server - refresh string + // Refresh token is the Refresh token returned by the server + Refresh string // ExpiredTimestamp is the Expires field but converted to a Go timestamp - expiredTimestamp time.Time + ExpiredTimestamp time.Time +} +// tokenRefresher is a structure that contains our access and refresh tokens and a timestamp when they expire. +// Additionally, it contains the refresher to get new tokens +type tokenRefresher struct { + Token // Refresher is the function that refreshes the token Refresher func(string) (*TokenResponse, time.Time, error) } @@ -44,7 +50,8 @@ type tokenLock struct { mu sync.Mutex // The token fields protected by the lock - t *token + // This token struct contains a refresher + t *tokenRefresher } // Access gets the OAuth access token used for contacting the server API @@ -57,17 +64,17 @@ func (l *tokenLock) Access() (string, error) { // The tokens are not expired yet // So they should be valid, re-login not neede if !l.expired() { - return l.t.access, nil + return l.t.Access, nil } // Check if refresh is even possible by doing a simple check if the refresh token is empty // This is not needed but reduces API calls to the server - if l.t.refresh == "" { + if l.t.Refresh == "" { return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) } // Otherwise refresh and then later return the access token if we are successful - tr, s, err := l.t.Refresher(l.t.refresh) + tr, s, err := l.t.Refresher(l.t.Refresh) if err != nil { // We have failed to ensure the tokens due to refresh not working return "", errors.Wrap( @@ -76,37 +83,50 @@ func (l *tokenLock) Access() (string, error) { if tr == nil { return "", errors.New("No token response after refreshing") } - l.updateInternal(*tr, s) - return l.t.access, nil + r := *tr + e := s.Add(time.Second * time.Duration(r.Expires)) + t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e} + l.updateInternal(t) + return l.t.Access, nil } -// Clear completely clears the token structure -// This is useful for forcing re-authorization -func (l *tokenLock) Clear() { +// UpdateResponse updates the structure using the server response and locks +func (l *tokenLock) UpdateResponse(r TokenResponse, s time.Time) { l.mu.Lock() - l.t = &token{} + e := s.Add(time.Second * time.Duration(r.Expires)) + t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e} + l.updateInternal(t) l.mu.Unlock() } -// updateInternal updates the structure using the response without locking -func (l *tokenLock) updateInternal(r TokenResponse, s time.Time) { - l.t.access = r.Access - l.t.refresh = r.Refresh - l.t.expiredTimestamp = s.Add(time.Second * time.Duration(r.Expires)) +// updateInternal updates the token structure internally but does not lock +func (l *tokenLock) updateInternal(r Token) { + l.t.Access = r.Access + l.t.Refresh = r.Refresh + l.t.ExpiredTimestamp = r.ExpiredTimestamp } -// Update updates the structure usign the response and locks -func (l *tokenLock) Update(r TokenResponse, s time.Time) { +// Update updates the token structure using the internal function but locks +func (l *tokenLock) Update(r Token) { l.mu.Lock() - l.updateInternal(r, s) + l.updateInternal(r) l.mu.Unlock() } + +// Get gets the tokens into a public struct +func (l *tokenLock) Get() Token { + // TODO: Check nil? + l.mu.Lock() + defer l.mu.Unlock() + return l.t.Token +} + // SetExpired overrides the timestamp to the current time // This marks the tokens as expired func (l *tokenLock) SetExpired() { l.mu.Lock() - l.t.expiredTimestamp = time.Now() + l.t.ExpiredTimestamp = time.Now() l.mu.Unlock() } @@ -114,5 +134,5 @@ func (l *tokenLock) SetExpired() { // This is only called internally and thus does not lock func (l *tokenLock) expired() bool { now := time.Now() - return !now.Before(l.t.expiredTimestamp) + return !now.Before(l.t.ExpiredTimestamp) } diff --git a/internal/server/server.go b/internal/server/server.go index 78f6472..1585264 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -77,6 +77,10 @@ func ShouldRenewButton(srv Server) bool { return true } +func UpdateTokens(srv Server, t oauth.Token) { + srv.OAuth().UpdateTokens(t) +} + func OAuthURL(srv Server, name string) (string, error) { return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } @@ -134,22 +138,33 @@ func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { return &ps, nil } -func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) { +type ConfigData struct { + // The configuration + Config string + + // The type of configuration + Type string + + // The tokens + Tokens oauth.Token +} + +func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { - return "", "", err + return nil, err } pID := b.Profiles.Current key, err := wireguard.GenerateKey() if err != nil { - return "", "", err + return nil, err } pub := key.PublicKey().String() cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport) if err != nil { - return "", "", err + return nil, err } // Store start and end time @@ -164,25 +179,39 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string cfg = wireguard.ConfigAddKey(cfg, key) } - return cfg, proto, nil + t := oauth.Token{} + + o := srv.OAuth() + if o != nil { + t = o.Token() + } + + return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil } -func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) { +func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { - return "", "", err + return nil, err } pid := b.Profiles.Current cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) if err != nil { - return "", "", err + return nil, err } // Store start and end time b.StartTime = time.Now() b.EndTime = exp - return cfg, "openvpn", nil + t := oauth.Token{} + + o := srv.OAuth() + if o != nil { + t = o.Token() + } + + return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil } func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { @@ -237,10 +266,10 @@ func RefreshEndpoints(srv Server) error { return b.InitializeEndpoints() } -func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) { +func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { - return "", "", err + return nil, err } ovpn := p.SupportsOpenVPN() @@ -266,7 +295,7 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (string, strin return openVPNGetConfig(server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: - return "", "", errors.Errorf("no supported protocol found") + return nil, errors.Errorf("no supported protocol found") } } diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 1090619..f0f31d6 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -5,7 +5,7 @@ from ctypes import CDLL, c_char_p, c_int, c_void_p, cdll from eduvpn_common import __version__ from eduvpn_common.types import ( - ConfigError, + cToken, DataError, ReadRxBytes, VPNStateChange, @@ -67,6 +67,7 @@ def initialize_functions(lib: CDLL) -> None: c_char_p ], c_void_p lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None + lib.FreeConfig.argtypes, lib.FreeConfig.restype = [c_void_p], None lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ c_void_p ], None @@ -81,17 +82,20 @@ def initialize_functions(lib: CDLL) -> None: c_char_p, c_char_p, c_int, - ], ConfigError + cToken, + ], DataError lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ c_char_p, c_char_p, c_int, - ], ConfigError + cToken, + ], DataError lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ c_char_p, c_char_p, c_int, - ], ConfigError + cToken, + ], DataError lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ c_char_p ], DataError diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 3cb45e1..304e2e8 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,13 +1,14 @@ import threading -from ctypes import c_int +from ctypes import cast, c_void_p, c_int, pointer from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from eduvpn_common.discovery import DiscoOrganizations, DiscoServers, get_disco_organizations, get_disco_servers from eduvpn_common.event import EventHandler from eduvpn_common.loader import initialize_functions, load_lib -from eduvpn_common.server import Profiles, Server, get_transition_server, get_servers +from eduvpn_common.server import Profiles, Config, Token, encode_tokens, get_config, Server, get_transition_server, get_servers from eduvpn_common.state import State, StateType from eduvpn_common.types import ReadRxBytes, VPNStateChange, decode_res, encode_args, get_data_error, get_bool +from eduvpn_common.types import VPNStateChange, ReadRXBytes, cToken, decode_res, encode_args, get_data_error, get_bool class EduVPN(object): @@ -219,26 +220,27 @@ class EduVPN(object): if remove_err: raise remove_err - def get_config(self, identifier: str, func: Any, prefer_tcp: bool = False) -> Tuple[str, str]: + def get_config(self, identifier: str, func: Any, prefer_tcp: bool = False, tokens: Optional[Token] = None) -> Optional[Config]: """Get an OpenVPN/WireGuard configuration from the server :param identifier: str: The identifier of the server, e.g. URL or ORG ID :param func: Any: The Go function to call :param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP + :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available :meta private: :raises WrappedError: An error by the Go library :return: The configuration and configuration type ('openvpn' or 'wireguard') - :rtype: Tuple[str, str] + :rtype: Config """ # Because it could be the case that a profile callback is started, store a threading event # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set # The event is set in self.set_profile self.profile_event = threading.Event() - config, config_type, config_err = self.go_function(func, identifier, prefer_tcp) + config, config_err = self.go_function(func, identifier, prefer_tcp, encode_tokens(tokens), decode_func=lambda lib, x: get_data_error(lib, x, get_config)) self.profile_event = None self.location_event = None @@ -246,49 +248,55 @@ class EduVPN(object): if config_err: raise config_err - return config, config_type + return config def get_config_custom_server( - self, url: str, prefer_tcp: bool = False - ) -> Tuple[str, str]: + self, url: str, prefer_tcp: bool = False, tokens: Optional[Token] = None + ) -> Optional[Config]: """Get an OpenVPN/WireGuard configuration from a custom server :param url: str: The URL of the custom server :param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP + :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available :raises WrappedError: An error by the Go library :return: The configuration and configuration type ('openvpn' or 'wireguard') - :rtype: Tuple[str, str] + :rtype: Config """ - return self.get_config(url, self.lib.GetConfigCustomServer, prefer_tcp) + return self.get_config(url, self.lib.GetConfigCustomServer, prefer_tcp, tokens) def get_config_institute_access( - self, url: str, prefer_tcp: bool = False - ) -> Tuple[str, str]: + self, url: str, prefer_tcp: bool = False, tokens: Optional[Token] = None + ) -> Optional[Config]: """Get an OpenVPN/WireGuard configuration from an institute access server :param url: str: The URL of the institute access server. Use the one from Discovery :param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP + :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available :raises WrappedError: An error by the Go library :return: The configuration and configuration type ('openvpn' or 'wireguard') - :rtype: Tuple[str, str] + :rtype: Config """ - return self.get_config(url, self.lib.GetConfigInstituteAccess, prefer_tcp) + return self.get_config(url, self.lib.GetConfigInstituteAccess, prefer_tcp, tokens) def get_config_secure_internet( - self, org_id: str, prefer_tcp: bool = False - ) -> Tuple[str, str]: + self, org_id: str, prefer_tcp: bool = False, tokens: Optional[Token] = None + ) -> Optional[Config]: """Get an OpenVPN/WireGuard configuration from a secure internet server :param org_id: str: The organization ID of the secure internet server. Use the one from Discovery :param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP + :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available :raises WrappedError: An error by the Go library + + :return: The configuration and configuration type ('openvpn' or 'wireguard') + :rtype: Config """ - return self.get_config(org_id, self.lib.GetConfigSecureInternet, prefer_tcp) + return self.get_config(org_id, self.lib.GetConfigSecureInternet, prefer_tcp, tokens) def go_back(self) -> None: """Go back in the FSM""" @@ -539,7 +547,6 @@ def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> in return 1 return 0 - def add_as_global_object(eduvpn: EduVPN) -> bool: """Add the provided parameter to the global objects lists so we can call the callback diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py index 380623d..d10584e 100644 --- a/wrappers/python/eduvpn_common/server.py +++ b/wrappers/python/eduvpn_common/server.py @@ -2,7 +2,7 @@ from ctypes import CDLL, POINTER, c_void_p, cast from datetime import datetime from typing import List, Optional, Type -from eduvpn_common.types import cServer, cServerLocations, cServerProfiles, cServers +from eduvpn_common.types import cConfig, cServer, cServerLocations, cServerProfiles, cServers, cToken class Profile: @@ -20,6 +20,34 @@ class Profile: def __str__(self): return self.display_name +class Token: + """The class that represents oauth Tokens + + :param: access: str: The access token + :param: refresh: str: The refresh token + :param: expired: int: The expire unix time + """ + def __init__(self, access: str, refresh: str, expired: int): + self.access = access + self.refresh = refresh + self.expires = expired + + +class Config: + """The class that represents an OpenVPN/WireGuard config + + :param: config: str: The config string + :param: config_type: str: The type of config, openvpn/wireguard + :param: tokens: Optional[Token]: The tokens + """ + def __init__(self, config: str, config_type: str, tokens: Optional[Token]): + self.config = config + self.config_type = config_type + self.tokens = tokens + + def __str__(self): + return self.config + class Profiles: """The class that represents a list of profiles @@ -347,3 +375,36 @@ def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]: lib.FreeSecureLocations(ptr) return location_list return None + + +def get_config(lib: CDLL, ptr: c_void_p) -> Optional[Config]: + """Get the config from the Go library as a C structure and return a Python usable structure + + :param lib: CDLL: The Go shared library + :param ptr: c_void_p: The C pointer to the confg structure + + :meta private: + + :return: The configuration if there is any + :rtype: Optional[Config] + """ + # TODO: FREE + if ptr: + config = cast(ptr, POINTER(cConfig)).contents + cfg = config.config.decode("utf-8") + cfg_type = config.config_type.decode("utf-8") + tokens = None + if config.token: + token_struct = config.token.contents + tokens = Token(token_struct.access.decode("utf-8"), token_struct.refresh.decode("utf-8"), token_struct.expired) + + config_class = Config(cfg, cfg_type, tokens) + lib.FreeConfig(ptr) + return config_class + return None + +def encode_tokens(arg: Optional[Token]) -> cToken: + if arg is None: + return cToken("".encode("utf-8"), "".encode("utf-8"), 0) + return cToken(arg.access.encode("utf-8"), arg.refresh.encode("utf-8"), arg.expires) + diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index 7e3ce9a..e4f8e26 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -15,6 +15,24 @@ from typing import Any, Callable, Iterator, List, Optional, Tuple from eduvpn_common.error import ErrorLevel, WrappedError +class cToken(Structure): + """The C type that represents the Token as forwarded to the Go library + + :meta private: + """ + _fields_ = [ + ("access", c_char_p), + ("refresh", c_char_p), + ("expired", c_ulonglong), + ] + + +class cConfig(Structure): + """The C type that represents the data that gets by the Go library returned when a config is obtained + + :meta private: + """ + _fields_ = [("config", c_char_p), ("config_type", c_char_p), ("token", POINTER(cToken))] class cError(Structure): """The C type that represents the Error as returned by the Go library @@ -156,19 +174,10 @@ class DataError(Structure): _fields_ = [("data", c_void_p), ("error", c_void_p)] -class ConfigError(Structure): - """The C type that represents the data that gets by the Go library returned when a config is obtained - - :meta private: - """ - _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)] - - # The type for a Go state change callback VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p) ReadRxBytes = CFUNCTYPE(c_ulonglong) - def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: """Encode the arguments ready to be used by the Go library @@ -182,8 +191,11 @@ def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: """ for arg, t in zip(args, types): # c_char_p needs the str to be encoded to bytes - if t is c_char_p: - arg = arg.encode("utf-8") + encode_map = { + c_char_p: lambda x: x.encode("utf-8"), + } + if t in encode_map: + arg = encode_map[t](arg) yield arg @@ -201,7 +213,6 @@ def decode_res(res: Any) -> Any: c_int: get_bool, c_void_p: get_error, DataError: get_data_error, - ConfigError: get_config_error, } return decode_map.get(res, lambda lib, x: x) @@ -268,25 +279,6 @@ def get_error(lib: CDLL, ptr: c_void_p) -> Optional[WrappedError]: return wrapped -def get_config_error( - lib: CDLL, config_error: ConfigError -) -> Tuple[str, str, Optional[WrappedError]]: - """Convert a C config structure to a Python usable config structure - - :param lib: CDLL: The Go shared library - :param config_error: ConfigError: The config error structure - - :meta private: - - :return: The configuration, configuration type ('openvpn'/'wireguard') and an optional error - :rtype: Tuple[str, str, Optional[WrappedError]] - """ - config = get_ptr_string(lib, config_error.config) - config_type = get_ptr_string(lib, config_error.config_type) - err = get_error(lib, config_error.error) - return config, config_type, err - - def get_data_error( lib: CDLL, data_error: DataError, data_conv: Callable = get_ptr_string ) -> Tuple[Any, Optional[WrappedError]]: |
