summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-20 15:43:55 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-21 18:28:50 +0100
commit12838c19514459974cf0a71c42f1248b1cb9419c (patch)
treea4254d20bb7b0ef49a2fa6c12753eb4c5acb64d1
parent6981666c6d8f639a1ff9c09a3bc08769e19928af (diff)
Exports + OAuth + Server: Forward tokens to getting a config
-rw-r--r--client/server.go107
-rw-r--r--exports/exports.go86
-rw-r--r--internal/oauth/oauth.go24
-rw-r--r--internal/oauth/token.go74
-rw-r--r--internal/server/server.go53
-rw-r--r--wrappers/python/eduvpn_common/loader.py12
-rw-r--r--wrappers/python/eduvpn_common/main.py43
-rw-r--r--wrappers/python/eduvpn_common/server.py63
-rw-r--r--wrappers/python/eduvpn_common/types.py54
9 files changed, 359 insertions, 157 deletions
diff --git a/client/server.go b/client/server.go
index 6802a47..0bb37a8 100644
--- a/client/server.go
+++ b/client/server.go
@@ -9,12 +9,14 @@ import (
"github.com/go-errors/errors"
)
+type ConfigData = server.ConfigData
+
// getConfigAuth gets a config with authorization and authentication.
// It also asks for a profile if no valid profile is found.
-func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, string, error) {
- err := c.ensureLogin(srv)
+func (c *Client) getConfigAuth(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) {
+ err := c.ensureLogin(srv, t)
if err != nil {
- return "", "", err
+ return nil, err
}
// TODO(jwijenbergh): Should we check if it returns false?
@@ -22,13 +24,13 @@ func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, strin
ok, err := server.HasValidProfile(srv, c.SupportsWireguard)
if err != nil {
- return "", "", err
+ return nil, err
}
// No valid profile, ask for one
if !ok {
if err = c.askProfile(srv); err != nil {
- return "", "", err
+ return nil, err
}
}
@@ -38,29 +40,28 @@ func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, strin
// retryConfigAuth retries the getConfigAuth function if the tokens are invalid.
// If OAuth is cancelled, it makes sure that we only forward the error as additional info.
-func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool) (string, string, error) {
- cfg, cfgType, err := c.getConfigAuth(srv, preferTCP)
+func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) {
+ cfg, err := c.getConfigAuth(srv, preferTCP, t)
if err == nil {
- return cfg, cfgType, nil
- }
- if err != nil {
- // Only retry if the error is that the tokens are invalid
- tErr := &oauth.TokensInvalidError{}
- if errors.As(err, &tErr) {
- cfg, cfgType, err = c.getConfigAuth(srv, preferTCP)
- if err == nil {
- return cfg, cfgType, nil
- }
+ return cfg, nil
+ }
+ // Only retry if the error is that the tokens are invalid
+ tErr := &oauth.TokensInvalidError{}
+ if errors.As(err, &tErr) {
+ // TODO: Is passing empty tokens correct here?
+ cfg, err = c.getConfigAuth(srv, preferTCP, oauth.Token{})
+ if err == nil {
+ return cfg, nil
}
- c.goBackInternal()
}
- return "", "", err
+ c.goBackInternal()
+ return nil, err
}
// getConfig gets an OpenVPN/WireGuard configuration by contacting the server, moving the FSM towards the DISCONNECTED state and then saving the local configuration file.
-func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, error) {
+func (c *Client) getConfig(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) {
if c.InFSMState(StateDeregistered) {
- return "", "", errors.Errorf("getConfig attempt in '%v'", StateDeregistered)
+ return nil, errors.Errorf("getConfig attempt in '%v'", StateDeregistered)
}
// Refresh the server endpoints
@@ -70,14 +71,14 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e
c.Logger.Warningf("failed to refresh server endpoints: %v", err)
}
- cfg, cfgType, err := c.retryConfigAuth(srv, preferTCP)
+ cfg, err := c.retryConfigAuth(srv, preferTCP, t)
if err != nil {
- return "", "", err
+ return nil, err
}
srv1, err := c.Servers.GetCurrentServer()
if err != nil {
- return "", "", err
+ return nil, err
}
// Signal the server display info
@@ -91,7 +92,7 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e
err.Error(), err.(*errors.Error).ErrorStack())
}
- return cfg, cfgType, nil
+ return cfg, nil
}
// SetSecureLocation sets the location for the current secure location server. countryCode is the secure location to be chosen.
@@ -227,7 +228,7 @@ func (c *Client) AddInstituteServer(url string) (srv server.Server, err error) {
c.FSM.GoTransition(StateChosenServer)
// Authorize it
- if err = c.ensureLogin(srv); err != nil {
+ if err = c.ensureLogin(srv, oauth.Token{}); err != nil {
// Removing is best effort
_ = c.RemoveInstituteAccess(url)
return nil, err
@@ -285,7 +286,7 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (srv server.Server, e
c.FSM.GoTransition(StateChosenServer)
// Authorize it
- if err = c.ensureLogin(srv); err != nil {
+ if err = c.ensureLogin(srv, oauth.Token{}); err != nil {
// Removing is best effort
_ = c.RemoveSecureInternet()
return nil, err
@@ -331,7 +332,7 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) {
c.FSM.GoTransition(StateChosenServer)
// Authorize it
- if err = c.ensureLogin(srv); err != nil {
+ if err = c.ensureLogin(srv, oauth.Token{}); err != nil {
// removing is best effort
_ = c.RemoveCustomServer(url)
return nil, err
@@ -344,7 +345,7 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) {
// GetConfigInstituteAccess gets a configuration for an Institute Access Server.
// It ensures that the Institute Access Server exists by creating or using an existing one with the url.
// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel.
-func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg string, cfgType string, err error) {
+func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) {
defer func() {
if err != nil {
c.logError(err)
@@ -353,7 +354,7 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg strin
// Not supported with Let's Connect!
if c.isLetsConnect() {
- return "", "", errors.Errorf("discovery with Let's Connect is not supported")
+ return nil, errors.Errorf("discovery with Let's Connect is not supported")
}
c.FSM.GoTransition(StateLoadingServer)
@@ -362,28 +363,28 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg strin
var srv *server.InstituteAccessServer
if srv, err = c.Servers.GetInstituteAccess(url); err != nil {
c.goBackInternal()
- return "", "", err
+ return nil, err
}
// Set the server as the current
if err = c.Servers.SetInstituteAccess(srv); err != nil {
- return "", "", err
+ return nil, err
}
// The server has now been chosen
c.FSM.GoTransition(StateChosenServer)
- if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil {
+ if cfg, err = c.getConfig(srv, preferTCP, t); err != nil {
c.goBackInternal()
}
- return cfg, cfgType, err
+ return cfg, err
}
// GetConfigSecureInternet gets a configuration for a Secure Internet Server.
// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID.
// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel.
-func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg string, cfgType string, err error) {
+func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) {
defer func() {
if err != nil {
c.logError(err)
@@ -392,7 +393,7 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg stri
// Not supported with Let's Connect!
if c.isLetsConnect() {
- return "", "", errors.Errorf("discovery with Let's Connect is not supported")
+ return nil, errors.Errorf("discovery with Let's Connect is not supported")
}
c.FSM.GoTransition(StateLoadingServer)
@@ -401,27 +402,27 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg stri
var srv *server.SecureInternetHomeServer
if srv, err = c.Servers.GetSecureInternetHomeServer(); err != nil {
c.goBackInternal()
- return "", "", err
+ return nil, err
}
// Set the server as the current
if err = c.Servers.SetSecureInternet(srv); err != nil {
- return "", "", err
+ return nil, err
}
c.FSM.GoTransition(StateChosenServer)
- if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil {
+ if cfg, err = c.getConfig(srv, preferTCP, t); err != nil {
c.goBackInternal()
}
- return cfg, cfgType, err
+ return cfg, err
}
// GetConfigCustomServer gets a configuration for a Custom Server.
// It ensures that the Custom Server exists by creating or using an existing one with the url.
// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel.
-func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, cfgType string, err error) {
+func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) {
defer func() {
if err != nil {
c.logError(err)
@@ -429,7 +430,7 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string,
}()
if url, err = util.EnsureValidURL(url); err != nil {
- return "", "", err
+ return nil, err
}
c.FSM.GoTransition(StateLoadingServer)
@@ -438,21 +439,21 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string,
var srv *server.InstituteAccessServer
if srv, err = c.Servers.GetCustomServer(url); err != nil {
c.goBackInternal()
- return "", "", err
+ return nil, err
}
// Set the server as the current
if err = c.Servers.SetCustomServer(srv); err != nil {
- return "", "", err
+ return nil, err
}
c.FSM.GoTransition(StateChosenServer)
- if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil {
+ if cfg, err = c.getConfig(srv, preferTCP, t); err != nil {
c.goBackInternal()
}
- return cfg, cfgType, err
+ return cfg, err
}
// askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state.
@@ -512,7 +513,7 @@ func (c *Client) RenewSession() (err error) {
}
server.MarkTokensForRenew(srv)
- return c.ensureLogin(srv)
+ return c.ensureLogin(srv, oauth.Token{})
}
// ShouldRenewButton returns true if the renew button should be shown
@@ -536,7 +537,7 @@ func (c *Client) ShouldRenewButton() bool {
// ensureLogin logs the user back in if needed.
// It runs the FSM transitions to ask for user input.
-func (c *Client) ensureLogin(srv server.Server) (err error) {
+func (c *Client) ensureLogin(srv server.Server, ct oauth.Token) (err error) {
// Relogin with oauth
// This moves the state to authorized
if !server.NeedsRelogin(srv) {
@@ -545,6 +546,16 @@ func (c *Client) ensureLogin(srv server.Server) (err error) {
return nil
}
+ // Try again but update the tokens using the client provided tokens
+ server.UpdateTokens(srv, ct)
+ if !server.NeedsRelogin(srv) {
+ // OAuth was valid, ensure we are in the authorized state
+ c.FSM.GoTransition(StateAuthorized)
+ return nil
+ }
+
+ // Tokens are not valid or the client gave an error when updating tokens
+ // Otherwise, do the OAuth exchange
var url string
if url, err = server.OAuthURL(srv, c.Name); err != nil {
return err
diff --git a/exports/exports.go b/exports/exports.go
index e374661..904fbec 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -5,6 +5,18 @@ package main
#include "error.h"
typedef long long int (*ReadRxBytes)();
+typedef struct token {
+ const char* access;
+ const char* refresh;
+ unsigned long long int expired;
+} token;
+
+typedef struct configData {
+ const char* config;
+ const char* config_type;
+ token* tokens;
+} configData;
+
typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data);
static long long int get_read_rx_bytes(ReadRxBytes read)
@@ -20,8 +32,10 @@ import "C"
import (
"unsafe"
+ "time"
"github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/go-errors/errors"
"github.com/eduvpn/eduvpn-common/client"
@@ -249,20 +263,56 @@ func RemoveCustomServer(name *C.char, url *C.char) *C.error {
return getError(removeErr)
}
+func cToken(t oauth.Token) *C.token {
+ cTok := (*C.token)(C.malloc(C.size_t(unsafe.Sizeof(C.token{}))))
+ cTok.access = C.CString(t.Access)
+ cTok.refresh = C.CString(t.Refresh)
+ cTok.expired = C.ulonglong(t.ExpiredTimestamp.Unix())
+ return cTok
+}
+
+func cConfig(config *client.ConfigData) *C.configData {
+ // No config so return nil pointer
+ if config == nil {
+ return nil
+ }
+ cConf := (*C.configData)(C.malloc(C.size_t(unsafe.Sizeof(C.configData{}))))
+ cConf.config = C.CString(config.Config)
+ cConf.config_type = C.CString(config.Type)
+ cConf.tokens = cToken(config.Tokens)
+ return cConf
+}
+
+//export FreeConfig
+func FreeConfig(config *C.configData) {
+ C.free(unsafe.Pointer(config.config))
+ C.free(unsafe.Pointer(config.config_type))
+ C.free(unsafe.Pointer(config.tokens.access))
+ C.free(unsafe.Pointer(config.tokens.refresh))
+ C.free(unsafe.Pointer(config.tokens))
+ C.free(unsafe.Pointer(config))
+}
+
//export GetConfigSecureInternet
func GetConfigSecureInternet(
name *C.char,
orgID *C.char,
preferTCP C.int,
-) (*C.char, *C.char, *C.error) {
+ prevTokens C.token,
+) (*C.configData, *C.error) {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
- return nil, nil, getError(stateErr)
+ return nil, getError(stateErr)
}
preferTCPBool := preferTCP == 1
- config, configType, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool)
- return C.CString(config), C.CString(configType), getError(configErr)
+ t := oauth.Token{
+ Access: C.GoString(prevTokens.access),
+ Refresh: C.GoString(prevTokens.refresh),
+ ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ }
+ cfg, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool, t)
+ return cConfig(cfg), getError(configErr)
}
//export GetConfigInstituteAccess
@@ -270,15 +320,21 @@ func GetConfigInstituteAccess(
name *C.char,
url *C.char,
preferTCP C.int,
-) (*C.char, *C.char, *C.error) {
+ prevTokens C.token,
+) (*C.configData, *C.error) {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
- return nil, nil, getError(stateErr)
+ return nil, getError(stateErr)
}
preferTCPBool := preferTCP == 1
- config, configType, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool)
- return C.CString(config), C.CString(configType), getError(configErr)
+ t := oauth.Token{
+ Access: C.GoString(prevTokens.access),
+ Refresh: C.GoString(prevTokens.refresh),
+ ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ }
+ cfg, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool, t)
+ return cConfig(cfg), getError(configErr)
}
//export GetConfigCustomServer
@@ -286,15 +342,21 @@ func GetConfigCustomServer(
name *C.char,
url *C.char,
preferTCP C.int,
-) (*C.char, *C.char, *C.error) {
+ prevTokens C.token,
+) (*C.configData, *C.error) {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
- return nil, nil, getError(stateErr)
+ return nil, getError(stateErr)
}
preferTCPBool := preferTCP == 1
- config, configType, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool)
- return C.CString(config), C.CString(configType), getError(configErr)
+ t := oauth.Token{
+ Access: C.GoString(prevTokens.access),
+ Refresh: C.GoString(prevTokens.refresh),
+ ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ }
+ cfg, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool, t)
+ return cConfig(cfg), getError(configErr)
}
//export SetProfileID
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index ce86337..6d21c82 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -196,10 +196,19 @@ func (oauth *OAuth) SetTokenExpired() {
// SetTokenRenew sets the tokens for renewal by completely clearing the structure.
func (oauth *OAuth) SetTokenRenew() {
if oauth.token != nil {
- oauth.token.Clear()
+ oauth.token.Update(Token{})
}
}
+func (oauth *OAuth) Token() Token {
+ t := Token{}
+ if oauth.token != nil {
+ t = oauth.token.Get()
+ }
+
+ return t
+}
+
// tokensWithAuthCode gets the access and refresh tokens using the authorization code
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
@@ -239,10 +248,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
return errors.New("No token response after authorization code")
}
- oauth.token.Update(*tr, now)
+ oauth.token.UpdateResponse(*tr, now)
return nil
}
+func (oauth *OAuth) UpdateTokens(t Token) {
+ if oauth.token == nil {
+ oauth.token = &tokenLock{t: &tokenRefresher{Refresher: oauth.refreshResponse}}
+ }
+ oauth.token.Update(t)
+}
+
// refreshResponse gets the refresh token response with a refresh token
// This response contains the access and refresh tokens, together with a timestamp
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
@@ -420,8 +436,8 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
return "", errors.WrapPrefix(err, "genState error", 0)
}
- // Fill the oauth tokens
- oauth.token = &tokenLock{t: &token{Refresher: oauth.refreshResponse}}
+ // Re-initialize the token structure
+ oauth.UpdateTokens(Token{})
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauth.session = exchangeSession{
diff --git a/internal/oauth/token.go b/internal/oauth/token.go
index 855677c..4ed8f43 100644
--- a/internal/oauth/token.go
+++ b/internal/oauth/token.go
@@ -23,17 +23,23 @@ type TokenResponse struct {
Expires int64 `json:"expires_in"`
}
-// token is a structure that contains our access and refresh tokens and a timestamp when they expire.
-type token struct {
- // Access is the access token returned by the server
- access string
+// The public type that can be passed to an update function
+// It contains our access and refresh tokens with a timestamp
+type Token struct {
+ // Access is the Access token returned by the server
+ Access string
- // Refresh token is the refresh token returned by the server
- refresh string
+ // Refresh token is the Refresh token returned by the server
+ Refresh string
// ExpiredTimestamp is the Expires field but converted to a Go timestamp
- expiredTimestamp time.Time
+ ExpiredTimestamp time.Time
+}
+// tokenRefresher is a structure that contains our access and refresh tokens and a timestamp when they expire.
+// Additionally, it contains the refresher to get new tokens
+type tokenRefresher struct {
+ Token
// Refresher is the function that refreshes the token
Refresher func(string) (*TokenResponse, time.Time, error)
}
@@ -44,7 +50,8 @@ type tokenLock struct {
mu sync.Mutex
// The token fields protected by the lock
- t *token
+ // This token struct contains a refresher
+ t *tokenRefresher
}
// Access gets the OAuth access token used for contacting the server API
@@ -57,17 +64,17 @@ func (l *tokenLock) Access() (string, error) {
// The tokens are not expired yet
// So they should be valid, re-login not neede
if !l.expired() {
- return l.t.access, nil
+ return l.t.Access, nil
}
// Check if refresh is even possible by doing a simple check if the refresh token is empty
// This is not needed but reduces API calls to the server
- if l.t.refresh == "" {
+ if l.t.Refresh == "" {
return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
}
// Otherwise refresh and then later return the access token if we are successful
- tr, s, err := l.t.Refresher(l.t.refresh)
+ tr, s, err := l.t.Refresher(l.t.Refresh)
if err != nil {
// We have failed to ensure the tokens due to refresh not working
return "", errors.Wrap(
@@ -76,37 +83,50 @@ func (l *tokenLock) Access() (string, error) {
if tr == nil {
return "", errors.New("No token response after refreshing")
}
- l.updateInternal(*tr, s)
- return l.t.access, nil
+ r := *tr
+ e := s.Add(time.Second * time.Duration(r.Expires))
+ t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e}
+ l.updateInternal(t)
+ return l.t.Access, nil
}
-// Clear completely clears the token structure
-// This is useful for forcing re-authorization
-func (l *tokenLock) Clear() {
+// UpdateResponse updates the structure using the server response and locks
+func (l *tokenLock) UpdateResponse(r TokenResponse, s time.Time) {
l.mu.Lock()
- l.t = &token{}
+ e := s.Add(time.Second * time.Duration(r.Expires))
+ t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e}
+ l.updateInternal(t)
l.mu.Unlock()
}
-// updateInternal updates the structure using the response without locking
-func (l *tokenLock) updateInternal(r TokenResponse, s time.Time) {
- l.t.access = r.Access
- l.t.refresh = r.Refresh
- l.t.expiredTimestamp = s.Add(time.Second * time.Duration(r.Expires))
+// updateInternal updates the token structure internally but does not lock
+func (l *tokenLock) updateInternal(r Token) {
+ l.t.Access = r.Access
+ l.t.Refresh = r.Refresh
+ l.t.ExpiredTimestamp = r.ExpiredTimestamp
}
-// Update updates the structure usign the response and locks
-func (l *tokenLock) Update(r TokenResponse, s time.Time) {
+// Update updates the token structure using the internal function but locks
+func (l *tokenLock) Update(r Token) {
l.mu.Lock()
- l.updateInternal(r, s)
+ l.updateInternal(r)
l.mu.Unlock()
}
+
+// Get gets the tokens into a public struct
+func (l *tokenLock) Get() Token {
+ // TODO: Check nil?
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ return l.t.Token
+}
+
// SetExpired overrides the timestamp to the current time
// This marks the tokens as expired
func (l *tokenLock) SetExpired() {
l.mu.Lock()
- l.t.expiredTimestamp = time.Now()
+ l.t.ExpiredTimestamp = time.Now()
l.mu.Unlock()
}
@@ -114,5 +134,5 @@ func (l *tokenLock) SetExpired() {
// This is only called internally and thus does not lock
func (l *tokenLock) expired() bool {
now := time.Now()
- return !now.Before(l.t.expiredTimestamp)
+ return !now.Before(l.t.ExpiredTimestamp)
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 78f6472..1585264 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -77,6 +77,10 @@ func ShouldRenewButton(srv Server) bool {
return true
}
+func UpdateTokens(srv Server, t oauth.Token) {
+ srv.OAuth().UpdateTokens(t)
+}
+
func OAuthURL(srv Server, name string) (string, error) {
return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
@@ -134,22 +138,33 @@ func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
return &ps, nil
}
-func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) {
+type ConfigData struct {
+ // The configuration
+ Config string
+
+ // The type of configuration
+ Type string
+
+ // The tokens
+ Tokens oauth.Token
+}
+
+func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
- return "", "", err
+ return nil, err
}
pID := b.Profiles.Current
key, err := wireguard.GenerateKey()
if err != nil {
- return "", "", err
+ return nil, err
}
pub := key.PublicKey().String()
cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport)
if err != nil {
- return "", "", err
+ return nil, err
}
// Store start and end time
@@ -164,25 +179,39 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string
cfg = wireguard.ConfigAddKey(cfg, key)
}
- return cfg, proto, nil
+ t := oauth.Token{}
+
+ o := srv.OAuth()
+ if o != nil {
+ t = o.Token()
+ }
+
+ return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil
}
-func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) {
+func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
- return "", "", err
+ return nil, err
}
pid := b.Profiles.Current
cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
if err != nil {
- return "", "", err
+ return nil, err
}
// Store start and end time
b.StartTime = time.Now()
b.EndTime = exp
- return cfg, "openvpn", nil
+ t := oauth.Token{}
+
+ o := srv.OAuth()
+ if o != nil {
+ t = o.Token()
+ }
+
+ return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil
}
func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
@@ -237,10 +266,10 @@ func RefreshEndpoints(srv Server) error {
return b.InitializeEndpoints()
}
-func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) {
+func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
p, err := CurrentProfile(server)
if err != nil {
- return "", "", err
+ return nil, err
}
ovpn := p.SupportsOpenVPN()
@@ -266,7 +295,7 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (string, strin
return openVPNGetConfig(server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
- return "", "", errors.Errorf("no supported protocol found")
+ return nil, errors.Errorf("no supported protocol found")
}
}
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
index 1090619..f0f31d6 100644
--- a/wrappers/python/eduvpn_common/loader.py
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -5,7 +5,7 @@ from ctypes import CDLL, c_char_p, c_int, c_void_p, cdll
from eduvpn_common import __version__
from eduvpn_common.types import (
- ConfigError,
+ cToken,
DataError,
ReadRxBytes,
VPNStateChange,
@@ -67,6 +67,7 @@ def initialize_functions(lib: CDLL) -> None:
c_char_p
], c_void_p
lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None
+ lib.FreeConfig.argtypes, lib.FreeConfig.restype = [c_void_p], None
lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
c_void_p
], None
@@ -81,17 +82,20 @@ def initialize_functions(lib: CDLL) -> None:
c_char_p,
c_char_p,
c_int,
- ], ConfigError
+ cToken,
+ ], DataError
lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [
c_char_p,
c_char_p,
c_int,
- ], ConfigError
+ cToken,
+ ], DataError
lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
c_char_p,
c_char_p,
c_int,
- ], ConfigError
+ cToken,
+ ], DataError
lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
c_char_p
], DataError
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index 3cb45e1..304e2e8 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,13 +1,14 @@
import threading
-from ctypes import c_int
+from ctypes import cast, c_void_p, c_int, pointer
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from eduvpn_common.discovery import DiscoOrganizations, DiscoServers, get_disco_organizations, get_disco_servers
from eduvpn_common.event import EventHandler
from eduvpn_common.loader import initialize_functions, load_lib
-from eduvpn_common.server import Profiles, Server, get_transition_server, get_servers
+from eduvpn_common.server import Profiles, Config, Token, encode_tokens, get_config, Server, get_transition_server, get_servers
from eduvpn_common.state import State, StateType
from eduvpn_common.types import ReadRxBytes, VPNStateChange, decode_res, encode_args, get_data_error, get_bool
+from eduvpn_common.types import VPNStateChange, ReadRXBytes, cToken, decode_res, encode_args, get_data_error, get_bool
class EduVPN(object):
@@ -219,26 +220,27 @@ class EduVPN(object):
if remove_err:
raise remove_err
- def get_config(self, identifier: str, func: Any, prefer_tcp: bool = False) -> Tuple[str, str]:
+ def get_config(self, identifier: str, func: Any, prefer_tcp: bool = False, tokens: Optional[Token] = None) -> Optional[Config]:
"""Get an OpenVPN/WireGuard configuration from the server
:param identifier: str: The identifier of the server, e.g. URL or ORG ID
:param func: Any: The Go function to call
:param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP
+ :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available
:meta private:
:raises WrappedError: An error by the Go library
:return: The configuration and configuration type ('openvpn' or 'wireguard')
- :rtype: Tuple[str, str]
+ :rtype: Config
"""
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
self.profile_event = threading.Event()
- config, config_type, config_err = self.go_function(func, identifier, prefer_tcp)
+ config, config_err = self.go_function(func, identifier, prefer_tcp, encode_tokens(tokens), decode_func=lambda lib, x: get_data_error(lib, x, get_config))
self.profile_event = None
self.location_event = None
@@ -246,49 +248,55 @@ class EduVPN(object):
if config_err:
raise config_err
- return config, config_type
+ return config
def get_config_custom_server(
- self, url: str, prefer_tcp: bool = False
- ) -> Tuple[str, str]:
+ self, url: str, prefer_tcp: bool = False, tokens: Optional[Token] = None
+ ) -> Optional[Config]:
"""Get an OpenVPN/WireGuard configuration from a custom server
:param url: str: The URL of the custom server
:param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP
+ :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available
:raises WrappedError: An error by the Go library
:return: The configuration and configuration type ('openvpn' or 'wireguard')
- :rtype: Tuple[str, str]
+ :rtype: Config
"""
- return self.get_config(url, self.lib.GetConfigCustomServer, prefer_tcp)
+ return self.get_config(url, self.lib.GetConfigCustomServer, prefer_tcp, tokens)
def get_config_institute_access(
- self, url: str, prefer_tcp: bool = False
- ) -> Tuple[str, str]:
+ self, url: str, prefer_tcp: bool = False, tokens: Optional[Token] = None
+ ) -> Optional[Config]:
"""Get an OpenVPN/WireGuard configuration from an institute access server
:param url: str: The URL of the institute access server. Use the one from Discovery
:param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP
+ :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available
:raises WrappedError: An error by the Go library
:return: The configuration and configuration type ('openvpn' or 'wireguard')
- :rtype: Tuple[str, str]
+ :rtype: Config
"""
- return self.get_config(url, self.lib.GetConfigInstituteAccess, prefer_tcp)
+ return self.get_config(url, self.lib.GetConfigInstituteAccess, prefer_tcp, tokens)
def get_config_secure_internet(
- self, org_id: str, prefer_tcp: bool = False
- ) -> Tuple[str, str]:
+ self, org_id: str, prefer_tcp: bool = False, tokens: Optional[Token] = None
+ ) -> Optional[Config]:
"""Get an OpenVPN/WireGuard configuration from a secure internet server
:param org_id: str: The organization ID of the secure internet server. Use the one from Discovery
:param prefer_tcp: bool: (Default value = False): Whether or not to prefer TCP
+ :param tokens: Optional[Token] (Default value = None): The OAuth tokens if available
:raises WrappedError: An error by the Go library
+
+ :return: The configuration and configuration type ('openvpn' or 'wireguard')
+ :rtype: Config
"""
- return self.get_config(org_id, self.lib.GetConfigSecureInternet, prefer_tcp)
+ return self.get_config(org_id, self.lib.GetConfigSecureInternet, prefer_tcp, tokens)
def go_back(self) -> None:
"""Go back in the FSM"""
@@ -539,7 +547,6 @@ def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> in
return 1
return 0
-
def add_as_global_object(eduvpn: EduVPN) -> bool:
"""Add the provided parameter to the global objects lists so we can call the callback
diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py
index 380623d..d10584e 100644
--- a/wrappers/python/eduvpn_common/server.py
+++ b/wrappers/python/eduvpn_common/server.py
@@ -2,7 +2,7 @@ from ctypes import CDLL, POINTER, c_void_p, cast
from datetime import datetime
from typing import List, Optional, Type
-from eduvpn_common.types import cServer, cServerLocations, cServerProfiles, cServers
+from eduvpn_common.types import cConfig, cServer, cServerLocations, cServerProfiles, cServers, cToken
class Profile:
@@ -20,6 +20,34 @@ class Profile:
def __str__(self):
return self.display_name
+class Token:
+ """The class that represents oauth Tokens
+
+ :param: access: str: The access token
+ :param: refresh: str: The refresh token
+ :param: expired: int: The expire unix time
+ """
+ def __init__(self, access: str, refresh: str, expired: int):
+ self.access = access
+ self.refresh = refresh
+ self.expires = expired
+
+
+class Config:
+ """The class that represents an OpenVPN/WireGuard config
+
+ :param: config: str: The config string
+ :param: config_type: str: The type of config, openvpn/wireguard
+ :param: tokens: Optional[Token]: The tokens
+ """
+ def __init__(self, config: str, config_type: str, tokens: Optional[Token]):
+ self.config = config
+ self.config_type = config_type
+ self.tokens = tokens
+
+ def __str__(self):
+ return self.config
+
class Profiles:
"""The class that represents a list of profiles
@@ -347,3 +375,36 @@ def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]:
lib.FreeSecureLocations(ptr)
return location_list
return None
+
+
+def get_config(lib: CDLL, ptr: c_void_p) -> Optional[Config]:
+ """Get the config from the Go library as a C structure and return a Python usable structure
+
+ :param lib: CDLL: The Go shared library
+ :param ptr: c_void_p: The C pointer to the confg structure
+
+ :meta private:
+
+ :return: The configuration if there is any
+ :rtype: Optional[Config]
+ """
+ # TODO: FREE
+ if ptr:
+ config = cast(ptr, POINTER(cConfig)).contents
+ cfg = config.config.decode("utf-8")
+ cfg_type = config.config_type.decode("utf-8")
+ tokens = None
+ if config.token:
+ token_struct = config.token.contents
+ tokens = Token(token_struct.access.decode("utf-8"), token_struct.refresh.decode("utf-8"), token_struct.expired)
+
+ config_class = Config(cfg, cfg_type, tokens)
+ lib.FreeConfig(ptr)
+ return config_class
+ return None
+
+def encode_tokens(arg: Optional[Token]) -> cToken:
+ if arg is None:
+ return cToken("".encode("utf-8"), "".encode("utf-8"), 0)
+ return cToken(arg.access.encode("utf-8"), arg.refresh.encode("utf-8"), arg.expires)
+
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
index 7e3ce9a..e4f8e26 100644
--- a/wrappers/python/eduvpn_common/types.py
+++ b/wrappers/python/eduvpn_common/types.py
@@ -15,6 +15,24 @@ from typing import Any, Callable, Iterator, List, Optional, Tuple
from eduvpn_common.error import ErrorLevel, WrappedError
+class cToken(Structure):
+ """The C type that represents the Token as forwarded to the Go library
+
+ :meta private:
+ """
+ _fields_ = [
+ ("access", c_char_p),
+ ("refresh", c_char_p),
+ ("expired", c_ulonglong),
+ ]
+
+
+class cConfig(Structure):
+ """The C type that represents the data that gets by the Go library returned when a config is obtained
+
+ :meta private:
+ """
+ _fields_ = [("config", c_char_p), ("config_type", c_char_p), ("token", POINTER(cToken))]
class cError(Structure):
"""The C type that represents the Error as returned by the Go library
@@ -156,19 +174,10 @@ class DataError(Structure):
_fields_ = [("data", c_void_p), ("error", c_void_p)]
-class ConfigError(Structure):
- """The C type that represents the data that gets by the Go library returned when a config is obtained
-
- :meta private:
- """
- _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)]
-
-
# The type for a Go state change callback
VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p)
ReadRxBytes = CFUNCTYPE(c_ulonglong)
-
def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]:
"""Encode the arguments ready to be used by the Go library
@@ -182,8 +191,11 @@ def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]:
"""
for arg, t in zip(args, types):
# c_char_p needs the str to be encoded to bytes
- if t is c_char_p:
- arg = arg.encode("utf-8")
+ encode_map = {
+ c_char_p: lambda x: x.encode("utf-8"),
+ }
+ if t in encode_map:
+ arg = encode_map[t](arg)
yield arg
@@ -201,7 +213,6 @@ def decode_res(res: Any) -> Any:
c_int: get_bool,
c_void_p: get_error,
DataError: get_data_error,
- ConfigError: get_config_error,
}
return decode_map.get(res, lambda lib, x: x)
@@ -268,25 +279,6 @@ def get_error(lib: CDLL, ptr: c_void_p) -> Optional[WrappedError]:
return wrapped
-def get_config_error(
- lib: CDLL, config_error: ConfigError
-) -> Tuple[str, str, Optional[WrappedError]]:
- """Convert a C config structure to a Python usable config structure
-
- :param lib: CDLL: The Go shared library
- :param config_error: ConfigError: The config error structure
-
- :meta private:
-
- :return: The configuration, configuration type ('openvpn'/'wireguard') and an optional error
- :rtype: Tuple[str, str, Optional[WrappedError]]
- """
- config = get_ptr_string(lib, config_error.config)
- config_type = get_ptr_string(lib, config_error.config_type)
- err = get_error(lib, config_error.error)
- return config, config_type, err
-
-
def get_data_error(
lib: CDLL, data_error: DataError, data_conv: Callable = get_ptr_string
) -> Tuple[Any, Optional[WrappedError]]: