diff options
Diffstat (limited to 'client/server.go')
| -rw-r--r-- | client/server.go | 107 |
1 files changed, 59 insertions, 48 deletions
diff --git a/client/server.go b/client/server.go index 6802a47..0bb37a8 100644 --- a/client/server.go +++ b/client/server.go @@ -9,12 +9,14 @@ import ( "github.com/go-errors/errors" ) +type ConfigData = server.ConfigData + // getConfigAuth gets a config with authorization and authentication. // It also asks for a profile if no valid profile is found. -func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, string, error) { - err := c.ensureLogin(srv) +func (c *Client) getConfigAuth(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) { + err := c.ensureLogin(srv, t) if err != nil { - return "", "", err + return nil, err } // TODO(jwijenbergh): Should we check if it returns false? @@ -22,13 +24,13 @@ func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, strin ok, err := server.HasValidProfile(srv, c.SupportsWireguard) if err != nil { - return "", "", err + return nil, err } // No valid profile, ask for one if !ok { if err = c.askProfile(srv); err != nil { - return "", "", err + return nil, err } } @@ -38,29 +40,28 @@ func (c *Client) getConfigAuth(srv server.Server, preferTCP bool) (string, strin // retryConfigAuth retries the getConfigAuth function if the tokens are invalid. // If OAuth is cancelled, it makes sure that we only forward the error as additional info. -func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool) (string, string, error) { - cfg, cfgType, err := c.getConfigAuth(srv, preferTCP) +func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) { + cfg, err := c.getConfigAuth(srv, preferTCP, t) if err == nil { - return cfg, cfgType, nil - } - if err != nil { - // Only retry if the error is that the tokens are invalid - tErr := &oauth.TokensInvalidError{} - if errors.As(err, &tErr) { - cfg, cfgType, err = c.getConfigAuth(srv, preferTCP) - if err == nil { - return cfg, cfgType, nil - } + return cfg, nil + } + // Only retry if the error is that the tokens are invalid + tErr := &oauth.TokensInvalidError{} + if errors.As(err, &tErr) { + // TODO: Is passing empty tokens correct here? + cfg, err = c.getConfigAuth(srv, preferTCP, oauth.Token{}) + if err == nil { + return cfg, nil } - c.goBackInternal() } - return "", "", err + c.goBackInternal() + return nil, err } // getConfig gets an OpenVPN/WireGuard configuration by contacting the server, moving the FSM towards the DISCONNECTED state and then saving the local configuration file. -func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, error) { +func (c *Client) getConfig(srv server.Server, preferTCP bool, t oauth.Token) (*ConfigData, error) { if c.InFSMState(StateDeregistered) { - return "", "", errors.Errorf("getConfig attempt in '%v'", StateDeregistered) + return nil, errors.Errorf("getConfig attempt in '%v'", StateDeregistered) } // Refresh the server endpoints @@ -70,14 +71,14 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e c.Logger.Warningf("failed to refresh server endpoints: %v", err) } - cfg, cfgType, err := c.retryConfigAuth(srv, preferTCP) + cfg, err := c.retryConfigAuth(srv, preferTCP, t) if err != nil { - return "", "", err + return nil, err } srv1, err := c.Servers.GetCurrentServer() if err != nil { - return "", "", err + return nil, err } // Signal the server display info @@ -91,7 +92,7 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool) (string, string, e err.Error(), err.(*errors.Error).ErrorStack()) } - return cfg, cfgType, nil + return cfg, nil } // SetSecureLocation sets the location for the current secure location server. countryCode is the secure location to be chosen. @@ -227,7 +228,7 @@ func (c *Client) AddInstituteServer(url string) (srv server.Server, err error) { c.FSM.GoTransition(StateChosenServer) // Authorize it - if err = c.ensureLogin(srv); err != nil { + if err = c.ensureLogin(srv, oauth.Token{}); err != nil { // Removing is best effort _ = c.RemoveInstituteAccess(url) return nil, err @@ -285,7 +286,7 @@ func (c *Client) AddSecureInternetHomeServer(orgID string) (srv server.Server, e c.FSM.GoTransition(StateChosenServer) // Authorize it - if err = c.ensureLogin(srv); err != nil { + if err = c.ensureLogin(srv, oauth.Token{}); err != nil { // Removing is best effort _ = c.RemoveSecureInternet() return nil, err @@ -331,7 +332,7 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) { c.FSM.GoTransition(StateChosenServer) // Authorize it - if err = c.ensureLogin(srv); err != nil { + if err = c.ensureLogin(srv, oauth.Token{}); err != nil { // removing is best effort _ = c.RemoveCustomServer(url) return nil, err @@ -344,7 +345,7 @@ func (c *Client) AddCustomServer(url string) (srv server.Server, err error) { // GetConfigInstituteAccess gets a configuration for an Institute Access Server. // It ensures that the Institute Access Server exists by creating or using an existing one with the url. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg string, cfgType string, err error) { +func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) { defer func() { if err != nil { c.logError(err) @@ -353,7 +354,7 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg strin // Not supported with Let's Connect! if c.isLetsConnect() { - return "", "", errors.Errorf("discovery with Let's Connect is not supported") + return nil, errors.Errorf("discovery with Let's Connect is not supported") } c.FSM.GoTransition(StateLoadingServer) @@ -362,28 +363,28 @@ func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool) (cfg strin var srv *server.InstituteAccessServer if srv, err = c.Servers.GetInstituteAccess(url); err != nil { c.goBackInternal() - return "", "", err + return nil, err } // Set the server as the current if err = c.Servers.SetInstituteAccess(srv); err != nil { - return "", "", err + return nil, err } // The server has now been chosen c.FSM.GoTransition(StateChosenServer) - if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil { + if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { c.goBackInternal() } - return cfg, cfgType, err + return cfg, err } // GetConfigSecureInternet gets a configuration for a Secure Internet Server. // It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg string, cfgType string, err error) { +func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) { defer func() { if err != nil { c.logError(err) @@ -392,7 +393,7 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg stri // Not supported with Let's Connect! if c.isLetsConnect() { - return "", "", errors.Errorf("discovery with Let's Connect is not supported") + return nil, errors.Errorf("discovery with Let's Connect is not supported") } c.FSM.GoTransition(StateLoadingServer) @@ -401,27 +402,27 @@ func (c *Client) GetConfigSecureInternet(orgID string, preferTCP bool) (cfg stri var srv *server.SecureInternetHomeServer if srv, err = c.Servers.GetSecureInternetHomeServer(); err != nil { c.goBackInternal() - return "", "", err + return nil, err } // Set the server as the current if err = c.Servers.SetSecureInternet(srv); err != nil { - return "", "", err + return nil, err } c.FSM.GoTransition(StateChosenServer) - if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil { + if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { c.goBackInternal() } - return cfg, cfgType, err + return cfg, err } // GetConfigCustomServer gets a configuration for a Custom Server. // It ensures that the Custom Server exists by creating or using an existing one with the url. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, cfgType string, err error) { +func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t oauth.Token) (cfg *ConfigData, err error) { defer func() { if err != nil { c.logError(err) @@ -429,7 +430,7 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, }() if url, err = util.EnsureValidURL(url); err != nil { - return "", "", err + return nil, err } c.FSM.GoTransition(StateLoadingServer) @@ -438,21 +439,21 @@ func (c *Client) GetConfigCustomServer(url string, preferTCP bool) (cfg string, var srv *server.InstituteAccessServer if srv, err = c.Servers.GetCustomServer(url); err != nil { c.goBackInternal() - return "", "", err + return nil, err } // Set the server as the current if err = c.Servers.SetCustomServer(srv); err != nil { - return "", "", err + return nil, err } c.FSM.GoTransition(StateChosenServer) - if cfg, cfgType, err = c.getConfig(srv, preferTCP); err != nil { + if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { c.goBackInternal() } - return cfg, cfgType, err + return cfg, err } // askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state. @@ -512,7 +513,7 @@ func (c *Client) RenewSession() (err error) { } server.MarkTokensForRenew(srv) - return c.ensureLogin(srv) + return c.ensureLogin(srv, oauth.Token{}) } // ShouldRenewButton returns true if the renew button should be shown @@ -536,7 +537,7 @@ func (c *Client) ShouldRenewButton() bool { // ensureLogin logs the user back in if needed. // It runs the FSM transitions to ask for user input. -func (c *Client) ensureLogin(srv server.Server) (err error) { +func (c *Client) ensureLogin(srv server.Server, ct oauth.Token) (err error) { // Relogin with oauth // This moves the state to authorized if !server.NeedsRelogin(srv) { @@ -545,6 +546,16 @@ func (c *Client) ensureLogin(srv server.Server) (err error) { return nil } + // Try again but update the tokens using the client provided tokens + server.UpdateTokens(srv, ct) + if !server.NeedsRelogin(srv) { + // OAuth was valid, ensure we are in the authorized state + c.FSM.GoTransition(StateAuthorized) + return nil + } + + // Tokens are not valid or the client gave an error when updating tokens + // Otherwise, do the OAuth exchange var url string if url, err = server.OAuthURL(srv, c.Name); err != nil { return err |
