diff options
Diffstat (limited to 'client/client.go')
| -rw-r--r-- | client/client.go | 864 |
1 files changed, 243 insertions, 621 deletions
diff --git a/client/client.go b/client/client.go index 4550ef0..5bcfb35 100644 --- a/client/client.go +++ b/client/client.go @@ -5,101 +5,94 @@ package client import ( "context" - "strings" + "errors" "sync" "time" "github.com/eduvpn/eduvpn-common/i18nerr" + "github.com/eduvpn/eduvpn-common/internal/api" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/failover" "github.com/eduvpn/eduvpn-common/internal/fsm" "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" "github.com/eduvpn/eduvpn-common/types/cookie" - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" srvtypes "github.com/eduvpn/eduvpn-common/types/server" - "github.com/go-errors/errors" + "github.com/jwijenbergh/eduoauth-go" ) // Client is the main struct for the VPN client. type Client struct { // The name of the client - Name string `json:"-"` + Name string - // The chosen server - Servers server.List `json:"servers"` - - // The list of servers and organizations from disco - Discovery discovery.Discovery `json:"discovery"` + // The servers + Servers server.Servers // The fsm - FSM fsm.FSM `json:"-"` - - // The config - Config config.Config `json:"-"` + FSM fsm.FSM // Whether or not this client supports WireGuard - SupportsWireguard bool `json:"-"` + SupportsWireguard bool // Whether to enable debugging - Debug bool `json:"-"` + Debug bool // TokenSetter sets the tokens in the client - TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"` + TokenSetter func(sid string, stype srvtypes.Type, tok srvtypes.Tokens) // TokenGetter gets the tokens from the client - TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"` + TokenGetter func(sid string, stype srvtypes.Type) *srvtypes.Tokens + + // tokenCacher + tokCacher TokenCacher + + // cfg is the config + cfg *config.Config mu sync.Mutex } -func (c *Client) updateTokens(srv server.Server) error { - if c.TokenGetter == nil { - return errors.New("no token getter defined") - } - pSrv, err := c.pubCurrentServer(srv) - if err != nil { - return err - } - // shouldn't happen - if pSrv == nil { - return errors.New("public server is nil when getting tokens") - } - tokens := c.TokenGetter(*pSrv) - if tokens == nil { - return errors.New("client returned nil for tokens") +func (c *Client) GettingConfig() error { + if c.FSM.InState(StateGettingConfig) { + return nil } - - server.UpdateTokens(srv, oauth.Token{ - Access: tokens.Access, - Refresh: tokens.Refresh, - ExpiredTimestamp: time.Unix(tokens.Expires, 0), - }) - - return nil + _, err := c.FSM.GoTransition(StateGettingConfig) + return err } -func (c *Client) forwardTokens(srv server.Server) error { - if c.TokenSetter == nil { - return errors.New("no token setter defined") - } - pSrv, err := c.pubCurrentServer(srv) +func (c *Client) InvalidProfile(ctx context.Context, srv *server.Server) (string, error) { + // TODO: should this have profiles as a parameter + ck := cookie.NewWithContext(ctx) + + prfs, err := srv.Profiles(ctx) if err != nil { - return err + return "", err } - if pSrv == nil { - return errors.New("public server is nil when updating tokens") + if !c.SupportsWireguard { + prfs = prfs.FilterWireGuard() } - o := srv.OAuth() - if o == nil { - return errors.New("oauth was nil when forwarding tokens") + // we are guaranteed to have profiles > 0 (even after filtering) + // because internally this callback is only triggered if there is a choice to make + + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ + C: ck, + Data: prfs.Public(), + }) + if err != nil { + errChan <- err + } + }() + pID, err := ck.Receive(errChan) + if err != nil { + return "", err } - t := o.Token() - c.TokenSetter(*pSrv, t.Public()) - return nil + + return pID, nil } func (c *Client) goTransition(id fsm.StateID) error { @@ -157,91 +150,100 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Debug only if given c.Debug = debug - // Initialize the Config - c.Config.Init(directory, "state") - - // Try to load the previous configuration - if c.Config.Load(&c) != nil { - // This error can be safely ignored, as when the config does not load, the struct will not be filled - log.Logger.Infof("Previous configuration not found") - } + c.cfg = config.NewFromDirectory(directory) + // set the servers + c.Servers = server.NewServers(c.Name, c, c.SupportsWireguard, c.cfg.V2) return c, nil } -// Registering means updating the FSM to get to the initial state correctly -func (c *Client) Register() error { - if !c.FSM.InState(StateDeregistered) { - return i18nerr.NewInternal("The client tried to re-initialize without deregistering first") +func (c *Client) TriggerAuth(ctx context.Context, url string, wait bool) (string, error) { + // Get a reply from the client + if wait { + ck := cookie.NewWithContext(ctx) + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{ + C: ck, + Data: url, + }) + if err != nil { + errChan <- err + } + }() + g, err := ck.Receive(errChan) + if err != nil { + return "", err + } + return g, nil } - err := c.goTransition(StateNoServer) + // Otherwise do normal authorization (desktop clients) + err := c.FSM.GoTransitionRequired(StateOAuthStarted, url) if err != nil { - return err + return "", err } - return nil + return "", nil } -// SaveState saves the internal state to the config -func (c *Client) SaveState() { - log.Logger.Debugf("saving state configuration....") - // Save the config - if err := c.Config.Save(&c); err != nil { - log.Logger.Infof("failed saving state configuration: '%v'", err) +func (c *Client) AuthDone(id string, t srvtypes.Type) { + srv, err := c.Servers.GetServer(id, t) + if err == nil { + srv.LastAuthorizeTime = time.Now() + } + // TODO: Should this log anything if it fails? + // unhandled transition? + _, err = c.FSM.GoTransition(StateMain) + if err != nil { + log.Logger.Debugf("unhandled auth done main transition: %v", err) } } -// Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct. -func (c *Client) Deregister() { - // First of all let's transition the state machine - _ = c.goTransition(StateDeregistered) - - // SaveState saves the configuration - c.SaveState() - - // Close the log file - _ = log.Logger.Close() - - // Empty out the state - *c = Client{} -} - -// DiscoOrganizations gets the organizations list from the discovery server -// If the list cannot be retrieved an error is returned. -// If this is the case then a previous version of the list is returned if there is any. -// This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list. -func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) { - // Not supported with Let's Connect! & govVPN - if !c.hasDiscovery() { - return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported") +func (c *Client) TokensUpdated(id string, t srvtypes.Type, tok eduoauth.Token) { + if tok.Access == "" { + return + } + // Set the memory + err := c.tokCacher.Set(id, t, tok) + if err != nil { + log.Logger.Warningf("failed to set tokens into cache with error: %v", err) } - // Mark organizations as expired if we have not set an organization yet - if !c.Servers.HasSecureInternet() { - c.Discovery.MarkOrganizationsExpired() + if c.TokenSetter == nil { + return } + // Update the client + c.TokenSetter(id, t, srvtypes.Tokens{ + Access: tok.Access, + Refresh: tok.Refresh, + Expires: tok.ExpiredTimestamp.Unix(), + }) +} - orgs, err = c.Discovery.Organizations(ck.Context()) +// Registering means updating the FSM to get to the initial state correctly +func (c *Client) Register() error { + err := c.goTransition(StateMain) if err != nil { - err = i18nerr.Wrap(err, "An error occurred after getting the discovery files for the list of organizations") + return err } - return + return nil } -// DiscoServers gets the servers list from the discovery server -// If the list cannot be retrieved an error is returned. -// If this is the case then a previous version of the list is returned if there is any. -// This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. -func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) { - // Not supported with Let's Connect! & govVPN - if !c.hasDiscovery() { - return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported") - } +// Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct. +func (c *Client) Deregister() { + // save the config + c.TrySave() - dss, err = c.Discovery.Servers(ck.Context()) + // Move the state machine back + _, err := c.FSM.GoTransition(StateDeregistered) if err != nil { - err = i18nerr.Wrap(err, "An error occurred after getting the discovery files for the list of servers") + log.Logger.Debugf("failed deregistered transition: %v", err) } - return + + // Close the log file + _ = log.Logger.Close() + + // Empty out the state + *c = Client{} } // ExpiryTimes returns the different Unix timestamps regarding expiry @@ -250,34 +252,21 @@ func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err e // - The list of times where notifications should be shown // These times are reset when the VPN gets disconnected func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { - // Get current expiry time - srv, err := c.Servers.Current() - if err != nil { - return nil, i18nerr.Wrap(err, "The current server could not be found when getting it for expiry") - } - b, err := srv.Base() + srv, err := c.Servers.CurrentServer() if err != nil { - return nil, err - } - - if b.StartTime.IsZero() { - return nil, i18nerr.New("No start time is defined for this server") + return nil, i18nerr.Wrap(err, "The current server was not found when getting the VPN expiration date") } - - bT := b.RenewButtonTime() - cT := b.CountdownTime() - nT := b.NotificationTimes() return &srvtypes.Expiry{ - StartTime: b.StartTime.Unix(), - EndTime: b.EndTime.Unix(), - ButtonTime: bT, - CountdownTime: cT, - NotificationTimes: nT, + StartTime: srv.LastAuthorizeTime.Unix(), + EndTime: srv.ExpireTime.Unix(), + ButtonTime: server.RenewButtonTime(srv.LastAuthorizeTime, srv.ExpireTime), + CountdownTime: server.CountdownTime(srv.LastAuthorizeTime, srv.ExpireTime), + NotificationTimes: server.NotificationTimes(srv.LastAuthorizeTime, srv.ExpireTime), }, nil } -func (c *Client) locationCallback(ck *cookie.Cookie) error { - locs := c.Discovery.SecureLocationList() +func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error { + locs := c.cfg.Discovery().SecureLocationList() errChan := make(chan error) go func() { err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ @@ -292,139 +281,19 @@ func (c *Client) locationCallback(ck *cookie.Cookie) error { if err != nil { return err } - err = c.SetSecureLocation(ck, loc) - if err != nil { - return err - } - err = c.goTransition(StateChosenLocation) + srv, err := c.Servers.GetServer(orgID, srvtypes.TypeSecureInternet) if err != nil { return err } + srv.CountryCode = loc return nil } -func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error { - // get a custom redirect - cr := CustomRedirect(c.Name) - url, err := server.OAuthURL(srv, c.Name, cr) - if err != nil { - return err - } - authCodeURI := "" - if cr != "" { - errChan := make(chan error) - go func() { - err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{ - C: ck, - Data: url, - }) - if err != nil { - errChan <- err - } - }() - g, err := ck.Receive(errChan) - if err != nil { - return err - } - authCodeURI = g - } else { - err = c.FSM.GoTransitionRequired(StateOAuthStarted, url) - if err != nil { - return err - } - } - err = server.OAuthExchange(ck.Context(), srv, authCodeURI) - if err != nil { - return err - } - return nil -} - -func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool, startup bool) error { - // location - if srv.NeedsLocation() { - if startup { - return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no secure internet location is found. Please manually connect again", server.Name(srv)) - } - err := c.locationCallback(ck) - if err != nil { - return i18nerr.Wrap(err, "The secure internet location could not be set") - } - } - - err := c.goTransition(StateChosenServer) - if err != nil { - log.Logger.Debugf("optional chosen server transition not possible: %v", err) - } - // oauth - // TODO: This should be ck.Context() - // But needsrelogin needs a rewrite to support this properly - - // first make sure we get the most up to date tokens from the client - err = c.updateTokens(srv) - if err != nil { - log.Logger.Debugf("failed to get tokens from client: %v", err) - } - if server.NeedsRelogin(context.Background(), srv) || forceauth { - if startup { - return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but you need to authorizate again. Please manually connect again", server.Name(srv)) - } - // mark organizations as expired if the server is a secure internet server - b, berr := srv.Base() - if berr == nil && b.Type == srvtypes.TypeSecureInternet { - c.Discovery.MarkOrganizationsExpired() - } - err := c.loginCallback(ck, srv) - if err != nil { - return i18nerr.Wrap(err, "The authorization procedure failed to complete") - } - } - err = c.goTransition(StateAuthorized) - if err != nil { - return err - } - - return nil -} - -func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server, startup bool) error { - vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard) - if err != nil { - log.Logger.Warningf("failed to determine whether the current protocol is valid with error: %v", err) - return err - } - if !vp { - if startup { - return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no valid profiles were found. Please manually connect again", server.Name(srv)) - } - vps, err := server.ValidProfiles(srv, c.SupportsWireguard) - if err != nil { - return i18nerr.Wrapf(err, "No suitable profiles could be found") - } - errChan := make(chan error) - go func() { - err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ - C: ck, - Data: vps.Public(), - }) - if err != nil { - errChan <- err - } - }() - pID, err := ck.Receive(errChan) - if err != nil { - return i18nerr.Wrapf(err, "Profile with ID: '%s' could not be set", pID) - } - err = server.Profile(srv, pID) - if err != nil { - return i18nerr.Wrapf(err, "Profile with ID: '%s' could not be obtained from the server", pID) - } - } - err = c.goTransition(StateChosenProfile) +func (c *Client) TrySave() { + err := c.cfg.Save() if err != nil { - return err + log.Logger.Warningf("failed to save configuration: %v", err) } - return nil } // AddServer adds a server with identifier and type @@ -435,463 +304,233 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. // We add the server because we can then obtain it in other callback functions previousState := c.FSM.Current defer func() { - if err != nil { - _ = c.RemoveServer(identifier, _type) //nolint:errcheck - } else { - c.SaveState() - } // If we must run callbacks, go to the previous state if we're not in it if !ni && !c.FSM.InState(previousState) { c.FSM.GoTransition(previousState) //nolint:errcheck } + if err == nil { + c.TrySave() + } }() if !ni { - err = c.goTransition(StateLoadingServer) + err = c.goTransition(StateAddingServer) // this is already wrapped in an UI error if err != nil { return err } } - if _type != srvtypes.TypeSecureInternet { + // Convert to an identifier identifier, err = http.EnsureValidURL(identifier, true) if err != nil { - return i18nerr.Wrap(err, "The identifier that was passed to the library is incorrect") + // TODO: wrap error + return err } } - var srv server.Server - switch _type { case srvtypes.TypeInstituteAccess: - dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access") - if err != nil { - return i18nerr.Wrapf(err, "Could not retrieve institute access server with URL: '%s' from discovery", identifier) - } - srv, err = c.Servers.AddInstituteAccess(ck.Context(), c.Name ,dSrv) + _, err = c.Servers.AddInstitute(ck.Context(), c.cfg.Discovery(), identifier, ni) if err != nil { return i18nerr.Wrapf(err, "The institute access server with URL: '%s' could not be added", identifier) } case srvtypes.TypeSecureInternet: - dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) - if err != nil { - // We mark the organizations as expired because we got an error - // Note that in the docs it states that it only should happen when the Org ID doesn't exist - // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found - c.Discovery.MarkOrganizationsExpired() - return i18nerr.Wrapf(err, "The secure internet server with organisation ID: '%s' could not be retrieved from discovery", identifier) - } - srv, err = c.Servers.AddSecureInternet(ck.Context(), c.Name, dOrg, dSrv) + _, err = c.Servers.AddSecure(ck.Context(), c.cfg.Discovery(), identifier, ni) if err != nil { return i18nerr.Wrapf(err, "The secure internet server with organisation ID: '%s' could not be added", identifier) } case srvtypes.TypeCustom: - srv, err = c.Servers.AddCustom(ck.Context(), c.Name, identifier) + _, err = c.Servers.AddCustom(ck.Context(), identifier, ni) if err != nil { return i18nerr.Wrapf(err, "The custom server with URL: '%s' could not be added", identifier) } default: return i18nerr.NewInternalf("Server type: '%v' is not valid to be added", _type) } - - // if we are non interactive, we run no callbacks - if ni { - return nil - } - - // callbacks - err = c.callbacks(ck, srv, false, false) - // error is already UI wrapped - if err != nil { - return err - } - terr := c.forwardTokens(srv) - if terr != nil { - log.Logger.Debugf("failed to forward tokens after adding: %v", terr) - } return nil } -func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool, startup bool) (cfg *srvtypes.Configuration, err error) { - // do the callbacks to ensure valid profile, location and authorization - err = c.callbacks(ck, srv, forceAuth, startup) - if err != nil { - return nil, err - } - - err = c.goTransition(StateRequestConfig) - if err != nil { - return nil, err - } - - err = c.profileCallback(ck, srv, startup) - if err != nil { - return nil, err - } - - cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP) - if err != nil { - return nil, i18nerr.Wrap(err, "The VPN configuration could not be obtained") +func (c *Client) convertIdentifier(identifier string, t srvtypes.Type) (string, error) { + // assume secure internet identifiers are always valid as we can't really assume they are valid urls (+ always https) + if t == srvtypes.TypeSecureInternet { + return identifier, nil } - p, err := server.CurrentProfile(srv) + // Convert to an identifier, this also converts the scheme to HTTPS + identifier, err := http.EnsureValidURL(identifier, true) if err != nil { - return nil, i18nerr.Wrap(err, "The current profile could not be found") + return "", i18nerr.Wrapf(err, "input: '%s' is not a valid URL", identifier) } - pcfg := cfgS.Public(p.DefaultGateway) - return &pcfg, nil -} - -func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) { - switch _type { - case srvtypes.TypeInstituteAccess: - srv, err = c.Servers.InstituteAccess(identifier) - setter = c.Servers.SetInstituteAccess - case srvtypes.TypeSecureInternet: - srv, err = c.Servers.SecureInternet(identifier) - setter = c.Servers.SetSecureInternet - case srvtypes.TypeCustom: - srv, err = c.Servers.CustomServer(identifier) - setter = c.Servers.SetCustom - default: - return nil, nil, i18nerr.NewInternalf("Not a valid server type: %v", _type) - } - return srv, setter, err + return identifier, nil } // GetConfig gets a VPN configuration -func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (cfg *srvtypes.Configuration, err error) { +func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (*srvtypes.Configuration, error) { c.mu.Lock() defer c.mu.Unlock() previousState := c.FSM.Current + var err error + defer func() { if err == nil { c.FSM.GoTransition(StateGotConfig) //nolint:errcheck - c.SaveState() } else if !c.FSM.InState(previousState) { // go back to the previous state if an error occurred c.FSM.GoTransition(previousState) //nolint:errcheck } }() - if _type != srvtypes.TypeSecureInternet { - identifier, err = http.EnsureValidURL(identifier, true) - if err != nil { - return nil, i18nerr.Wrapf(err, "Identifier: '%s' for server with type: '%d' is not valid", identifier, _type) - } - } - err = c.goTransition(StateLoadingServer) - if err != nil { - return nil, err - } - srv, set, err := c.server(identifier, _type) - if err != nil { - return nil, err - } - // refresh the server endpoints - err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) - // If we get a canceled error, return that, otherwise just log the error + identifier, err = c.convertIdentifier(identifier, _type) if err != nil { - if errors.Is(err, context.Canceled) { - return nil, i18nerr.Wrap(err, "The operation for getting a VPN configuration was canceled") - } - - log.Logger.Warningf("failed to refresh server endpoints: %v", err) + return nil, i18nerr.Wrapf(err, "Server identifier: '%s', is not valid when getting a VPN configuration", identifier) } - - // get a config and retry with authorization if expired - cfg, err = c.config(ck, srv, pTCP, false, startup) - tErr := &oauth.TokensInvalidError{} - if err != nil && errors.As(err, &tErr) { - log.Logger.Debugf("the tokens were invalid, trying again...") - cfg, err = c.config(ck, srv, pTCP, true, startup) - } - - // tokens might be updated, forward them - defer func() { - terr := c.forwardTokens(srv) - if terr != nil { - log.Logger.Debugf("failed to forward tokens after get config: %v", terr) - } - }() - - // still an error, return nil with the error + err = c.GettingConfig() if err != nil { - return nil, err - } - - // set the current server - if err = set(srv); err != nil { - return nil, i18nerr.Wrapf(err, "Failed to set the server with identifier: '%s' as the current", identifier) + log.Logger.Debugf("failed getting config transition: %v", err) } - return cfg, nil -} - -func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { - if _type != srvtypes.TypeSecureInternet { - identifier, err = http.EnsureValidURL(identifier, true) - if err != nil { - return i18nerr.Wrapf(err, "Identifier: '%s' for server with type: '%d' is not valid for removal", identifier, _type) - } + tok, err := c.retrieveTokens(identifier, _type) + if err != nil { + log.Logger.Debugf("no tokens found for server: '%s', with error: '%v'", identifier, err) } - // miscellaneous error - var mErr error + var srv *server.Server switch _type { case srvtypes.TypeInstituteAccess: - mErr = c.Servers.RemoveInstituteAccess(identifier) + srv, err = c.Servers.GetInstitute(ck.Context(), identifier, c.cfg.Discovery(), tok, startup) case srvtypes.TypeSecureInternet: - mErr = c.Servers.RemoveSecureInternet(identifier) + srv, err = c.Servers.GetSecure(ck.Context(), identifier, c.cfg.Discovery(), tok, startup) + + var cErr *discovery.CountryNotFoundError + if errors.As(err, &cErr) { + err = c.locationCallback(ck, identifier) + if err == nil { + srv, err = c.Servers.GetSecure(ck.Context(), identifier, c.cfg.Discovery(), tok, startup) + } + } case srvtypes.TypeCustom: - mErr = c.Servers.RemoveCustom(identifier) + srv, err = c.Servers.GetCustom(ck.Context(), identifier, tok, startup) default: - return i18nerr.NewInternalf("Not a valid server type: %v", _type) + err = i18nerr.NewInternalf("Server type: '%v' is not valid to get a config for", _type) } - if mErr != nil { - log.Logger.Debugf("failed to remove server with identifier: '%s' and type: '%d', error: %v", identifier, _type, mErr) + if err != nil { + if startup { + if errors.Is(err, api.ErrAuthorizeDisabled) { + return nil, i18nerr.Newf("The client tried to autoconnect to the VPN server: '%s', but you need to authorizate again. Please manually connect again", identifier) + } + return nil, i18nerr.Wrapf(err, "The client tried to autoconnect to the VPN server: '%s', but the operation failed to complete", identifier) + } + return nil, i18nerr.Wrapf(err, "Server: '%s' could not be obtained", identifier) } - c.SaveState() - return nil -} -func (c *Client) CurrentServer() (*srvtypes.Current, error) { - srv, err := c.Servers.Current() + cfg, err := c.Servers.ConnectWithCallbacks(ck.Context(), srv, pTCP) if err != nil { - return nil, err + return nil, i18nerr.Wrapf(err, "No VPN configuration for server: '%s' could be obtained", identifier) } - return c.pubCurrentServer(srv) + return cfg, nil } -func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) { - b, err := srv.Base() +func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { + identifier, err = c.convertIdentifier(identifier, _type) if err != nil { - return nil, err + return i18nerr.Wrapf(err, "Server identifier: '%s', is not valid when removing the server", identifier) } - pub, err := srv.Public() + err = c.Servers.Remove(identifier, _type) if err != nil { - return nil, err - } - switch t := pub.(type) { - case *srvtypes.Server: - if b.Type == srvtypes.TypeInstituteAccess { - return &srvtypes.Current{ - Institute: &srvtypes.Institute{ - Server: *t, - SupportContacts: b.SupportContact, - // TODO: delisted - Delisted: false, - }, - Type: srvtypes.TypeInstituteAccess, - }, nil - } - return &srvtypes.Current{ - Custom: t, - Type: srvtypes.TypeCustom, - }, nil - case *srvtypes.SecureInternet: - t.SupportContacts = b.SupportContact - t.Locations = c.Discovery.SecureLocationList() - return &srvtypes.Current{ - SecureInternet: t, - Type: srvtypes.TypeSecureInternet, - }, nil - default: - panic("unknown type") + return i18nerr.Wrapf(err, "The server: '%s' could not be removed", identifier) } + return nil } -// TODO: This should not rely on interface{} -func (c *Client) pubServer(srv server.Server) (interface{}, error) { - pub, err := srv.Public() +func (c *Client) CurrentServer() (*srvtypes.Current, error) { + curr, err := c.Servers.PublicCurrent(c.cfg.Discovery()) if err != nil { - return nil, err + return nil, i18nerr.Wrap(err, "The current server could not be retrieved") } - b, err := srv.Base() + return curr, nil +} + +func (c *Client) SetProfileID(pID string) error { + srv, err := c.Servers.CurrentServer() if err != nil { - return nil, err - } - switch t := pub.(type) { - case *srvtypes.Server: - if b.Type == srvtypes.TypeInstituteAccess { - return &srvtypes.Institute{ - Server: *t, - SupportContacts: b.SupportContact, - // TODO: delisted - Delisted: false, - }, nil - } - return t, nil - case *srvtypes.SecureInternet: - t.SupportContacts = b.SupportContact - t.Locations = c.Discovery.SecureLocationList() - return t, nil - default: - panic("unknown type") + return i18nerr.Wrapf(err, "Failed to set the profile ID: '%s'", pID) } + srv.Profiles.Current = pID + return nil } -func (c *Client) ServerList() (*srvtypes.List, error) { - if c.FSM.InState(StateDeregistered) { - return nil, i18nerr.NewInternal("Client is not registered") - } - var customServers []srvtypes.Server - for _, v := range c.Servers.CustomServers.Map { - if v == nil { - continue - } - p, err := c.pubServer(v) - if err != nil { - continue - } - c, ok := p.(*srvtypes.Server) - if !ok { - continue - } - customServers = append(customServers, *c) +func (c *Client) retrieveTokens(sid string, t srvtypes.Type) (*eduoauth.Token, error) { + // get from memory + tok, err := c.tokCacher.Get(sid, t) + if err == nil { + return tok, nil } - var instituteServers []srvtypes.Institute - for _, v := range c.Servers.InstituteServers.Map { - if v == nil { - continue - } - p, err := c.pubServer(v) - if err != nil { - continue - } - i, ok := p.(*srvtypes.Institute) - if !ok { - continue - } - instituteServers = append(instituteServers, *i) + if c.TokenGetter == nil { + return tok, err } - var secureInternet *srvtypes.SecureInternet - if c.Servers.HasSecureInternet() { - srv := &c.Servers.SecureInternetHomeServer - p, err := c.pubServer(srv) - if err == nil { - s, ok := p.(*srvtypes.SecureInternet) - if ok { - secureInternet = s - } - } + // get from client + gtok := c.TokenGetter(sid, t) + if gtok == nil { + return nil, errors.New("client returned nil tokens") } - return &srvtypes.List{ - Institutes: instituteServers, - SecureInternet: secureInternet, - Custom: customServers, + return &eduoauth.Token{ + Access: gtok.Access, + Refresh: gtok.Refresh, + ExpiredTimestamp: time.Unix(gtok.Expires, 0), }, nil } -func (c *Client) SetProfileID(pID string) (err error) { - srv, err := c.Servers.Current() +func (c *Client) Cleanup(ck *cookie.Cookie) error { + srv, err := c.Servers.CurrentServer() if err != nil { - return err - } - err = server.Profile(srv, pID) - if err == nil { - c.SaveState() + return i18nerr.Wrap(err, "The current server was not found when cleaning up the connection") } - return err -} - -func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { - // get the current server - srv, err := c.Servers.Current() + tok, err := c.retrieveTokens(srv.T.ID, srv.T.T) if err != nil { - return i18nerr.Wrap(err, "Failed to get the current server to cleanup the connection") + return i18nerr.Wrap(err, "No OAuth tokens were found when cleaning up the connection") } - - err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) - - // If we get a canceled error, return that, otherwise just log the error + auth, err := srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), tok, true) if err != nil { - if errors.Is(err, context.Canceled) { - return i18nerr.Wrap(err, "The cleanup process was canceled") - } - - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - - defer c.SaveState() - err = c.updateTokens(srv) - if err != nil { - log.Logger.Debugf("failed to update tokens for disconnect: %v", err) - } - err = server.Disconnect(ck.Context(), srv) - if err != nil { - return i18nerr.Wrap(err, "Failed to cleanup the VPN connection for the current server") + return i18nerr.Wrap(err, "The server was unable to be retrieved when cleaning up the connection") } - err = c.forwardTokens(srv) + err = auth.Disconnect(ck.Context()) if err != nil { - log.Logger.Debugf("failed to forward tokens after disconnect: %v", err) + return i18nerr.Wrap(err, "Failed to cleanup the VPN connection") } return nil } -func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) { +func (c *Client) SetSecureLocation(orgID string, countryCode string) error { // not supported with Let's Connect! & govVPN if !c.hasDiscovery() { return i18nerr.NewInternal("Setting a secure internet location with this client ID is not supported") } - - if !c.Servers.HasSecureInternet() { - return i18nerr.Newf("No secure internet server available to set a location for") - } - - dSrv, err := c.Discovery.ServerByCountryCode(countryCode) + srv, err := c.Servers.GetServer(orgID, srvtypes.TypeSecureInternet) if err != nil { - return err - } - - err = c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv) - if err == nil { - c.SaveState() + return i18nerr.Wrapf(err, "Failed to get the secure internet server with id: '%s' for setting a location", orgID) } - return err + srv.CountryCode = countryCode + return nil } -func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { - c.mu.Lock() - defer c.mu.Unlock() - srv, err := c.Servers.Current() +func (c *Client) RenewSession(ck *cookie.Cookie) error { + // getting the current serving with nil tokens means re-authorize + srv, err := c.Servers.CurrentServer() if err != nil { - return i18nerr.Wrap(err, "Failed to get current server for renewing the session") - } - // The server has not been chosen yet, this means that we want to manually renew - // TODO: is this needed? - if !c.FSM.InState(StateLoadingServer) { - c.FSM.GoTransition(StateLoadingServer) //nolint:errcheck + return i18nerr.Wrap(err, "The current server could not be retrieved when renewing the session") } - err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) - // If we get a canceled error, return that, otherwise just log the error + // getting a server with no tokens means re-authorize + _, err = srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), nil, false) if err != nil { - if errors.Is(err, context.Canceled) { - return i18nerr.Wrap(err, "The renewing process was canceled") - } - - log.Logger.Warningf("failed to refresh server endpoints: %v", err) + return i18nerr.Wrap(err, "The server was unable to be retrieved when renewing the session") } - - - // update tokens in the end - defer func() { - terr := c.forwardTokens(srv) - if terr != nil { - log.Logger.Debugf("failed to forward tokens after renew: %v", terr) - } - }() - defer c.SaveState() - // TODO: Maybe this can be deleted because we force auth now - server.MarkTokensForRenew(srv) - // run the callbacks by forcing auth - return c.callbacks(ck, srv, true, false) + return nil } func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readRxBytes func() (int64, error)) (bool, error) { f := failover.New(readRxBytes) + // get current profile d, err := f.Start(ck.Context(), gateway, mtu) if err != nil { return d, i18nerr.Wrapf(err, "Failover failed to complete with gateway: '%s' and MTU: '%d'", gateway, mtu) @@ -899,24 +538,7 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR return d, nil } -func (c *Client) SetState(state FSMStateID) error { - c.mu.Lock() - defer c.mu.Unlock() - curr := c.FSM.Current - _, err := c.FSM.GoTransition(state) - if err != nil { - // self-transitions are only debug errors - if c.FSM.InState(state) { - log.Logger.Debugf("attempt an invalid self-transition: %s", c.FSM.GetStateName(state)) - return nil - } - return i18nerr.WrapInternalf(err, "Failed internal state transition requested by the client from: '%s' to '%s'", GetStateName(curr), GetStateName(state)) - } - return nil -} - -func (c *Client) InState(state FSMStateID) bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.FSM.InState(state) +func (c *Client) ServerList() (*srvtypes.List, error) { + g := c.cfg.V2.PublicList(c.cfg.Discovery()) + return g, nil } |
