diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-04-12 22:55:16 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | a38e3e79f74e95051db7e14ae14ab817b68b725a (patch) | |
| tree | e26cab53f993d2d845020f81ee6f6f6a8a12ded1 /client | |
| parent | 4d228ba2084eb810d0cc33308893b00d1bb3eb02 (diff) | |
Refactor: Move client implementation to one file
Much easier to oversee and it forces me to keep the client type as
small as possible. This also uses the cookie for cancellation
We also no longer require tokens inside arguments. We will later
implement them with callbacks
Diffstat (limited to 'client')
| -rw-r--r-- | client/client.go | 622 | ||||
| -rw-r--r-- | client/client_test.go | 84 | ||||
| -rw-r--r-- | client/fsm.go | 48 | ||||
| -rw-r--r-- | client/server.go | 701 |
4 files changed, 513 insertions, 942 deletions
diff --git a/client/client.go b/client/client.go index 70adb71..813f6dc 100644 --- a/client/client.go +++ b/client/client.go @@ -2,9 +2,9 @@ package client import ( + "context" "fmt" "strings" - "sync" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -14,32 +14,12 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" + "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type ( - // ServerBase is an alias to the internal ServerBase - // This contains the details for each server. - ServerBase = server.Base -) - -func (c *Client) logError(err error) { - // Logs the error with the same level/verbosity as the error - if c.Debug { - log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) - } else { - log.Logger.Inherit(err, "") - } -} - -func (c *Client) isLetsConnect() bool { - // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php - return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") -} - // isAllowedClientID checks if the 'clientID' is in the list of allowed client IDs func isAllowedClientID(clientID string) bool { allowList := []string{ @@ -91,13 +71,27 @@ func userAgentName(clientID string) string { } } +func (c *Client) logError(err error) { + // Logs the error with the same level/verbosity as the error + if c.Debug { + log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) + } else { + log.Logger.Inherit(err, "") + } +} + +func (c *Client) isLetsConnect() bool { + // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php + return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") +} + // Client is the main struct for the VPN client. type Client struct { // The name of the client Name string `json:"-"` // The chosen server - Servers server.Servers `json:"servers"` + Servers server.List `json:"servers"` // The list of servers and organizations from disco Discovery discovery.Discovery `json:"discovery"` @@ -116,9 +110,6 @@ type Client struct { // The Failover monitor for the current VPN connection Failover *failover.DroppedConMon - - locationWg sync.WaitGroup - profileWg sync.WaitGroup } // New creates a new client with the following parameters: @@ -179,7 +170,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Registering means updating the FSM to get to the initial state correctly func (c *Client) Register() error { - if !c.InFSMState(StateDeregistered) { + if !c.FSM.InState(StateDeregistered) { return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) } c.FSM.GoTransition(StateNoServer) @@ -200,27 +191,11 @@ func (c *Client) Deregister() { *c = Client{} } -// askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. -func (c *Client) askProfile(srv server.Server) error { - ps, err := server.ValidProfiles(srv, c.SupportsWireguard) - if err != nil { - return err - } - - c.profileWg.Add(1) - if err = c.FSM.GoTransitionRequired(StateAskProfile, convertProfiles(*ps)); err != nil { - return err - } - c.profileWg.Wait() - - return nil -} - // DiscoOrganizations gets the organizations list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list. -func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error) { +func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) { defer func() { if err != nil { c.logError(err) @@ -237,14 +212,15 @@ func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error c.Discovery.MarkOrganizationsExpired() } - return c.Discovery.Organizations() + // TODO: pass a context + return c.Discovery.Organizations(ck.Context()) } // DiscoServers gets the servers list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. -func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { +func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) { defer func() { if err != nil { c.logError(err) @@ -256,7 +232,8 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { return nil, errors.Errorf("discovery with Let's Connect is not supported") } - return c.Discovery.Servers() + // TODO: pass a context + return c.Discovery.Servers(ck.Context()) } // ExpiryTimes returns the different Unix timestamps regarding expiry @@ -266,7 +243,7 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { // These times are reset when the VPN gets disconnected func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { // Get current expiry time - srv, err := c.Servers.GetCurrentServer() + srv, err := c.Servers.Current() if err != nil { c.logError(err) return nil, err @@ -293,149 +270,488 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func convertProfiles(profiles server.ProfileInfo) srvtypes.Profiles { - m := make(map[string]srvtypes.Profile) - for _, p := range profiles.Info.ProfileList { - var protocols []protocol.Protocol - // loop through all protocol strings - for _, ps := range p.VPNProtoList { - protocols = append(protocols, protocol.New(ps)) - } - m[p.ID] = srvtypes.Profile{ - DisplayName: map[string]string{ - "en": p.DisplayName, - }, - Protocols: protocols, +func (c *Client) locationCallback(ck *cookie.Cookie) error { + locs := c.Discovery.SecureLocationList() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ + C: ck, + Data: locs, + }) + if err != nil { + errChan <- err } + }() + loc, err := ck.Receive(errChan) + if err != nil { + return err } - return srvtypes.Profiles{Map: m, Current: profiles.Current} + err = c.SetSecureLocation(ck, loc) + if err != nil { + return err + } + t := c.FSM.GoTransition(StateChosenLocation) + if !t { + log.Logger.Warningf("transition chosen location not completed") + } + return nil } -func convertGeneric(server server.InstituteAccessServer) (*srvtypes.Server, error) { - b, err := server.Base() +func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error { + url, err := server.OAuthURL(srv, c.Name) if err != nil { - return nil, err + return err } - return &srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - }, nil + err = c.FSM.GoTransitionRequired(StateOAuthStarted, url) + if err != nil { + return err + } + err = server.OAuthExchange(ck.Context(), srv) + if err != nil { + return err + } + return nil } -// TODO: CLEAN THIS UP -func (c *Client) ServerList() (*srvtypes.List, error) { - custom := c.Servers.CustomServers - var customServers []srvtypes.Server - for _, v := range custom.Map { - if v == nil { - return nil, errors.New("found nil value in custom server map") +func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) error { + // location + if srv.NeedsLocation() { + err := c.locationCallback(ck) + if err != nil { + return err } - conv, err := convertGeneric(*v) + } + + t := c.FSM.GoTransition(StateChosenServer) + if !t { + log.Logger.Warningf("transition not completed for chosen server") + } + // oauth + // TODO: This should be ck.Context() + // But needsrelogin needs a rewrite to support this properly + if server.NeedsRelogin(context.Background(), srv) || forceauth { + err := c.loginCallback(ck, srv) if err != nil { - return nil, errors.Errorf("failed to convert custom server for public type: %v", err) + return err } - customServers = append(customServers, *conv) } - institute := c.Servers.InstituteServers - var instituteServers []srvtypes.Institute - for _, v := range institute.Map { - if v == nil { - return nil, errors.New("found nil value in institute server map") + t = c.FSM.GoTransition(StateAuthorized) + if !t { + log.Logger.Warningf("transition authorized not completed") + } + + return nil +} + +func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error { + vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard) + if err != nil { + return err + } + if !vp { + b, err := srv.Base() + if err != nil { + return err } - conv, err := convertGeneric(*v) + ps := b.Profiles.Public() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ + C: ck, + Data: ps, + }) + if err != nil { + errChan <- err + } + }() + pID, err := ck.Receive(errChan) if err != nil { - return nil, errors.Errorf("failed to convert institute server for public type: %v", err) + return err + } + err = server.Profile(srv, pID) + if err != nil { + return err } - instituteServers = append(instituteServers, srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }) } + t := c.FSM.GoTransition(StateChosenProfile) + if !t { + log.Logger.Warningf("transition chosen profile not completed") + } + return nil +} - var secureInternet *srvtypes.SecureInternet - if c.Servers.HasSecureInternet() { - b, err := c.Servers.SecureInternetHomeServer.Base() - if err == nil { - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - secureInternet = &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, - // TODO: delisted - Delisted: false, - } +// AddServer adds a server with identifier and type +func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) { + // If we have failed to add the server, we remove it again + // We add the server because we can then obtain it in other callback functions + defer func() { + if err != nil { + _ = c.RemoveServer(identifier, _type) //nolint:errcheck } + c.FSM.GoTransition(StateNoServer) + }() + if !ni { + if !c.FSM.InState(StateNoServer) { + return errors.Errorf("wrong state to add a server: %s", GetStateName(c.FSM.Current)) + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } } - return &srvtypes.List{ - Institutes: instituteServers, - SecureInternet: secureInternet, - Custom: customServers, - }, nil -} -// TODO: CLEAN THIS UP -func (c *Client) CurrentServer() (*srvtypes.Current, error) { - srvs := c.Servers + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + + var srv server.Server - switch srvs.IsType { - case server.InstituteAccessServerType: - curr, err := srvs.GetInstituteAccess(srvs.InstituteServers.CurrentURL) + switch _type { + case srvtypes.TypeInstituteAccess: + dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access") if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddInstituteAccess(ck.Context(), dSrv) if err != nil { - return nil, err + return err } - return &srvtypes.Current{ - Institute: &srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }, - Type: srvtypes.TypeInstituteAccess, - }, nil - case server.CustomServerType: - curr, err := srvs.GetCustomServer(srvs.CustomServers.CurrentURL) + case srvtypes.TypeSecureInternet: + dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) if err != nil { - return nil, err + return err + } + case srvtypes.TypeCustom: + srv, err = c.Servers.AddCustom(ck.Context(), identifier) + if err != nil { + return err + } + default: + return errors.Errorf("not a valid server type: %v", _type) + } + + // if we are non interactive, we run no callbacks + if ni { + return nil + } + + // callbacks + return c.callbacks(ck, srv, false) +} + +func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { + // do the callbacks to ensure valid profile, location and authorization + err = c.callbacks(ck, srv, forceAuth) + if err != nil { + return nil, err + } + + t := c.FSM.GoTransition(StateRequestConfig) + if !t { + log.Logger.Warningf("transition not completed for requesting config") + } + + err = c.profileCallback(ck, srv) + if err != nil { + return nil, err + } + + cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP) + if err != nil { + return nil, err + } + p, err := server.CurrentProfile(srv) + if err != nil { + return nil, err + } + pcfg := cfgS.Public(p.DefaultGateway) + if err != nil { + return nil, err + } + return &pcfg, nil +} + +func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) { + switch _type { + case srvtypes.TypeInstituteAccess: + srv, err = c.Servers.InstituteAccess(identifier) + setter = c.Servers.SetInstituteAccess + case srvtypes.TypeSecureInternet: + srv, err = c.Servers.SecureInternet(identifier) + setter = c.Servers.SetSecureInternet + case srvtypes.TypeCustom: + srv, err = c.Servers.CustomServer(identifier) + setter = c.Servers.SetCustom + default: + return nil, nil, errors.Errorf("not a valid server type: %v", _type) + } + return srv, setter, err +} + +// GetConfig gets a VPN configuration +func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) { + defer func() { + if err == nil { + c.FSM.GoTransition(StateGotConfig) + } else { + // go back if an error occurred + c.FSM.GoTransition(StateNoServer) + } + }() + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return nil, err + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } + srv, set, err := c.server(identifier, _type) + if err != nil { + return nil, err + } + // refresh the server endpoints + err = server.RefreshEndpoints(ck.Context(), srv) + if err != nil { + log.Logger.Warningf("failed to refresh server endpoints: %v", err) + } + + // get a config and retry with authorization if expired + cfg, err = c.config(ck, srv, pTCP, false) + tErr := &oauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + cfg, err = c.config(ck, srv, pTCP, true) + } + + // still an error, return nil with the error + if err != nil { + return nil, err + } + + // set the current server + if err = set(srv); err != nil { + return nil, err + } + + return cfg, nil +} + +func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + switch _type { + case srvtypes.TypeInstituteAccess: + return c.Servers.RemoveInstituteAccess(identifier) + case srvtypes.TypeSecureInternet: + return c.Servers.RemoveSecureInternet(identifier) + case srvtypes.TypeCustom: + return c.Servers.RemoveCustom(identifier) + default: + return errors.Errorf("not a valid server type: %v", _type) + } +} + +func (c *Client) CurrentServer() (*srvtypes.Current, error) { + if !c.FSM.InState(StateGotConfig) { + return nil, errors.Errorf("State: %s, cannot have a current server. Did you get a VPN configuration?", GetStateName(c.FSM.Current)) + } + srv, err := c.Servers.Current() + if err != nil { + return nil, err + } + return c.pubCurrentServer(srv) +} + +func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) { + b, err := srv.Base() + if err != nil { + return nil, err + } + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Current{ + Institute: &srvtypes.Institute{ + Server: *t, + // TODO: delisted + Delisted: false, + }, + Type: srvtypes.TypeInstituteAccess, + }, nil } return &srvtypes.Current{ - Custom: conv, + Custom: t, Type: srvtypes.TypeCustom, }, nil - case server.SecureInternetServerType: - b, err := c.Servers.SecureInternetHomeServer.Base() + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return &srvtypes.Current{ + SecureInternet: t, + Type: srvtypes.TypeSecureInternet, + }, nil + default: + panic("unknown type") + } +} + +// TODO: This should not rely on interface{} +func (c *Client) pubServer(srv server.Server) (interface{}, error) { + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + b, err := srv.Base() if err != nil { return nil, err } - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: c.Servers.SecureInternetHomeServer.HomeOrganizationID, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - return &srvtypes.Current{ - SecureInternet: &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Institute{ + Server: *t, // TODO: delisted Delisted: false, - }, - Type: srvtypes.TypeSecureInternet, - }, nil + }, nil + } + return t, nil + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return t, nil default: - return nil, errors.New("current server not found") + panic("unknown type") + } +} + +func (c *Client) ServerList() (*srvtypes.List, error) { + if c.FSM.InState(StateDeregistered) { + return nil, errors.New("client is not registered") + } + var customServers []srvtypes.Server + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + c, ok := p.(*srvtypes.Server) + if !ok { + continue + } + customServers = append(customServers, *c) } + var instituteServers []srvtypes.Institute + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + i, ok := p.(*srvtypes.Institute) + if !ok { + continue + } + instituteServers = append(instituteServers, *i) + } + var secureInternet *srvtypes.SecureInternet + if c.Servers.HasSecureInternet() { + srv := &c.Servers.SecureInternetHomeServer + p, err := c.pubServer(srv) + if err == nil { + s, ok := p.(*srvtypes.SecureInternet) + if ok { + secureInternet = s + } + } + } + return &srvtypes.List{ + Institutes: instituteServers, + SecureInternet: secureInternet, + Custom: customServers, + }, nil +} + +func (c *Client) SetProfileID(pID string) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + return server.Profile(srv, pID) +} + +func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { + // get the current server + srv, err := c.Servers.Current() + if err != nil { + return err + } + // TODO: Support cookie context here + // if server.NeedsRelogin(context.Background(), srv) { + // // TODO: ask client for tokens + // } + err = server.Disconnect(ck.Context(), srv) + if err != nil { + return err + } + // TODO: Set tokens with callback + return nil +} + +func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) { + if c.isLetsConnect() { + return errors.Errorf("setting a secure internet location with Let's Connect! is not supported") + } + + if !c.Servers.HasSecureInternet() { + return errors.Errorf("no secure internet server available to set a location for") + } + + dSrv, err := c.Discovery.ServerByCountryCode(countryCode) + if err != nil { + return err + } + + return c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv) +} + +func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + // The server has not been chosen yet, this means that we want to manually renew + // TODO: is this needed? + if !c.FSM.InState(StateChosenServer) { + c.FSM.GoTransition(StateLoadingServer) + c.FSM.GoTransition(StateChosenServer) + } + // TODO: Maybe this can be deleted because we force auth now + server.MarkTokensForRenew(srv) + // run the callbacks by forcing auth + return c.callbacks(ck, srv, true) +} + +func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { + if c.Failover != nil { + return false, errors.New("another failover process is already started") + } + + c.Failover = failover.New(readRxBytes) + + return c.Failover.Start(ck.Context(), gateway, wgMTU) } diff --git a/client/client_test.go b/client/client_test.go index 56c38ff..7077ce4 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "net/http" "net/url" @@ -12,6 +13,7 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/types/cookie" "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" @@ -22,7 +24,7 @@ func getServerURI(t *testing.T) string { if serverURI == "" { t.Skip("Skipping server test as no SERVER_URI env var has been passed") } - serverURI, parseErr := httpw.EnsureValidURL(serverURI) + serverURI, parseErr := httpw.EnsureValidURL(serverURI, true) if parseErr != nil { t.Skip("Skipping server test as the server uri is not valid") } @@ -41,13 +43,13 @@ func runCommand(errBuffer *strings.Builder, name string, args ...string) error { return cmd.Wait() } -func loginOAuthSelenium(url string, state *Client) { +func loginOAuthSelenium(ck *cookie.Cookie, url string) { // We could use the go selenium library // But it does not support the latest selenium v4 just yet var errBuffer strings.Builder err := runCommand(&errBuffer, "python3", "../selenium_eduvpn.py", url) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() panic(fmt.Sprintf( "Login OAuth with selenium script failed with error %v and stderr %s", err, @@ -58,10 +60,10 @@ func loginOAuthSelenium(url string, state *Client) { func stateCallback( t *testing.T, + ck *cookie.Cookie, _ FSMStateID, newState FSMStateID, data interface{}, - state *Client, ) { if newState == StateOAuthStarted { url, ok := data.(string) @@ -69,20 +71,20 @@ func stateCallback( if !ok { t.Fatalf("data is not a string for OAuth URL") } - loginOAuthSelenium(url, state) + loginOAuthSelenium(ck, url) } } func TestServer(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configstest", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -95,12 +97,11 @@ func TestServer(t *testing.T) { t.Fatalf("Registering error: %v", err) } - - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error: %v", configErr) } @@ -112,33 +113,36 @@ func testConnectOAuthParameter( errPrefix string, ) { serverURI := getServerURI(t) - state := &Client{} configDirectory := "test_oauth_parameters" + state := &Client{} + + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", configDirectory, func(oldState FSMStateID, newState FSMStateID, data interface{}) bool { if newState == StateOAuthStarted { - server, serverErr := state.Servers.GetCustomServer(serverURI) + server, serverErr := state.Servers.CustomServer(serverURI) if serverErr != nil { t.Fatalf("No server with error: %v", serverErr) } port, portErr := server.OAuth().ListenerPort() if portErr != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf("No port with error: %v", portErr) } baseURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) p, err := url.Parse(baseURL) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf("Failed to parse URL with error: %v", err) } url, err := httpw.ConstructURL(p, parameters) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf( "Error: Constructing url %s with parameters %s", baseURL, @@ -148,7 +152,7 @@ func testConnectOAuthParameter( go func() { _, getErr := http.Get(url) if getErr != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Logf("HTTP GET error: %v", getErr) } }() @@ -165,7 +169,7 @@ func testConnectOAuthParameter( t.Fatalf("Registering error: %v", err) } - err = state.AddCustomServer(serverURI) + err = state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if errPrefix == "" { if err != nil { @@ -247,14 +251,14 @@ func TestTokenExpired(t *testing.T) { } // Get a vpn state - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsexpired", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -267,25 +271,25 @@ func TestTokenExpired(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error before expired: %v", configErr) } - currentServer, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.Current() if serverErr != nil { t.Fatalf("No server found") } serverOAuth := currentServer.OAuth() - accessToken, accessTokenErr := serverOAuth.AccessToken() + accessToken, accessTokenErr := serverOAuth.AccessToken(ck.Context()) if accessTokenErr != nil { t.Fatalf("Failed to get token: %v", accessTokenErr) } @@ -293,14 +297,14 @@ func TestTokenExpired(t *testing.T) { // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error after expiry: %v", configErr) } // Check if tokens have changed - accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken() + accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken(ck.Context()) if accessTokenAfterErr != nil { t.Fatalf("Failed to get token: %v", accessTokenAfterErr) } @@ -313,14 +317,14 @@ func TestTokenExpired(t *testing.T) { // Test if an invalid profile will be corrected. func TestInvalidProfileCorrected(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configscancelprofile", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -333,18 +337,18 @@ func TestInvalidProfileCorrected(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("First connect error: %v", configErr) } - currentServer, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.Current() if serverErr != nil { t.Fatalf("No server found") } @@ -357,7 +361,7 @@ func TestInvalidProfileCorrected(t *testing.T) { previousProfile := base.Profiles.Current base.Profiles.Current = "IDONOTEXIST" - _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Second connect error: %v", configErr) } @@ -374,14 +378,14 @@ func TestInvalidProfileCorrected(t *testing.T) { // Test if prefer tcp is handled correctly by checking the returned config and config type. func TestPreferTCP(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsprefertcp", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -394,13 +398,13 @@ func TestPreferTCP(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } // get a config with preferTCP set to true - config, configErr := state.GetConfigCustomServer(serverURI, true, srvtypes.Tokens{}) + config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true) // Test server should accept prefer TCP! if config.Protocol != protocol.OpenVPN { @@ -417,7 +421,7 @@ func TestPreferTCP(t *testing.T) { } // get a config with preferTCP set to false - config, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Config error: %v", configErr) } diff --git a/client/fsm.go b/client/fsm.go index 9f140e3..038e6cf 100644 --- a/client/fsm.go +++ b/client/fsm.go @@ -2,9 +2,6 @@ package client import ( "github.com/eduvpn/eduvpn-common/internal/fsm" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/server" - "github.com/go-errors/errors" ) type ( @@ -94,7 +91,6 @@ func newFSM( }, StateNoServer: FSMState{ Transitions: []FSMTransition{ - {To: StateNoServer, Description: "Reload list"}, {To: StateLoadingServer, Description: "User clicks a server in the UI"}, }, }, @@ -170,47 +166,3 @@ func newFSM( returnedFSM.Init(StateDeregistered, states, callback, directory, GetStateName, debug) return returnedFSM } - -// GoBack transitions the FSM back to the previous UI state, for now this is always the NO_SERVER state. -func (c *Client) GoBack() error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("fsm attempt going back from 'StateDeregistered'") - c.logError(err) - return err - } - - // FIXME: Arbitrary back transitions don't work because we need the appropriate data - c.FSM.GoTransition(StateNoServer) - return nil -} - -// goBackInternal uses the public go back but logs an error if it happened. -func (c *Client) goBackInternal() { - err := c.GoBack() - if err != nil { - // TODO(jwijenbergh): Bit suspicious - logging level INFO, yet stacktrace logged. - log.Logger.Infof("failed going back: %s\nstacktrace:\n%s", err.Error(), err.(*errors.Error).ErrorStack()) - } -} - -// CancelOAuth cancels OAuth if one is in progress. -// If OAuth is not in progress, it returns an error. -// An error is also returned if OAuth is in progress, but it fails to cancel it. -func (c *Client) CancelOAuth() error { - if !c.InFSMState(StateOAuthStarted) { - return errors.Errorf("fsm attempt cancelling OAuth while in '%v'", c.FSM.Current) - } - - srv, err := c.Servers.GetCurrentServer() - if err != nil { - c.logError(err) - } else { - server.CancelOAuth(srv) - } - return err -} - -// InFSMState is a helper to check if the FSM is in state `checkState`. -func (c *Client) InFSMState(checkState FSMStateID) bool { - return c.FSM.InState(checkState) -} diff --git a/client/server.go b/client/server.go deleted file mode 100644 index b3f7747..0000000 --- a/client/server.go +++ /dev/null @@ -1,701 +0,0 @@ -package client - -import ( - "time" - - "github.com/eduvpn/eduvpn-common/internal/failover" - "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/server" - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/eduvpn/eduvpn-common/types/protocol" - srvtypes "github.com/eduvpn/eduvpn-common/types/server" - "github.com/go-errors/errors" -) - -// TODO: This should not be reliant on an internal type -func getTokens(tok oauth.Token) srvtypes.Tokens { - return srvtypes.Tokens{ - Access: tok.Access, - Refresh: tok.Refresh, - Expires: tok.ExpiredTimestamp.Unix(), - } -} - -// getConfigAuth gets a config with authorization and authentication. -// It also asks for a profile if no valid profile is found. -func (c *Client) getConfigAuth(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) { - err := c.ensureLogin(srv, t) - if err != nil { - return nil, err - } - - // TODO(jwijenbergh): Should we check if it returns false? - c.FSM.GoTransition(StateRequestConfig) - - ok, err := server.HasValidProfile(srv, c.SupportsWireguard) - if err != nil { - return nil, err - } - - // No valid profile, ask for one - if !ok { - if err = c.askProfile(srv); err != nil { - return nil, err - } - } - - // The profile has been chosen - c.FSM.GoTransition(StateChosenProfile) - - cfg, err := server.Config(srv, c.SupportsWireguard, preferTCP) - if err != nil { - return nil, err - } - - p, err := server.CurrentProfile(srv) - if err != nil { - return nil, err - } - - pCfg := &srvtypes.Configuration{ - VPNConfig: cfg.Config, - Protocol: protocol.New(cfg.Type), - DefaultGateway: p.DefaultGateway, - Tokens: getTokens(cfg.Tokens), - } - - return pCfg, nil -} - -// retryConfigAuth retries the getConfigAuth function if the tokens are invalid. -// If OAuth is cancelled, it makes sure that we only forward the error as additional info. -func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) { - cfg, err := c.getConfigAuth(srv, preferTCP, t) - if err == nil { - return cfg, nil - } - // Only retry if the error is that the tokens are invalid - tErr := &oauth.TokensInvalidError{} - if errors.As(err, &tErr) { - // TODO: Is passing empty tokens correct here? - cfg, err = c.getConfigAuth(srv, preferTCP, srvtypes.Tokens{}) - if err == nil { - return cfg, nil - } - } - c.goBackInternal() - return nil, err -} - -// getConfig gets an OpenVPN/WireGuard configuration by contacting the server, moving the FSM towards the DISCONNECTED state and then saving the local configuration file. -func (c *Client) getConfig(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) { - if c.InFSMState(StateDeregistered) { - return nil, errors.Errorf("getConfig attempt in '%v'", StateDeregistered) - } - - // Refresh the server endpoints - // This is the best effort - err := srv.RefreshEndpoints(&c.Discovery) - if err != nil { - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - cfg, err := c.retryConfigAuth(srv, preferTCP, t) - if err != nil { - return nil, err - } - - // Save the config - if err = c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - - c.FSM.GoTransition(StateGotConfig) - - return cfg, nil -} - -// Cleanup cleans up the VPN connection by sending a /disconnect to the server -func (c *Client) Cleanup(ct srvtypes.Tokens) error { - srv, err := c.Servers.GetCurrentServer() - if err != nil { - c.logError(err) - return err - } - err = srv.RefreshEndpoints(&c.Discovery) - if err != nil { - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - // If we need to relogin, update tokens - if server.NeedsRelogin(srv) { - server.UpdateTokens(srv, oauth.Token{ - Access: ct.Access, - Refresh: ct.Refresh, - ExpiredTimestamp: time.Unix(ct.Expires, 0), - }) - } - // update tokens to client - defer c.ForwardTokenUpdate(srv) - // Do the /disconnect API call - err = server.Disconnect(srv) - if err != nil { - // We log nothing here because this can happen regularly - // Maybe we should not log errors that we return directly anyways? - return err - } - // TODO: Tokens might be refreshed, return updated tokens - // Not implemented yet, because ideally we want this implemented with an interface - return nil -} - -// SetSecureLocation sets the location for the current secure location server. countryCode is the secure location to be chosen. -// This function returns an error e.g. if the server cannot be found or the location is wrong. -func (c *Client) SetSecureLocation(countryCode string) error { - if c.InFSMState(StateAskLocation) { - defer c.locationWg.Done() - } - // Not supported with Let's Connect! - if c.isLetsConnect() { - err := errors.Errorf("discovery with Let's Connect is not supported") - c.logError(err) - return err - } - - srv, err := c.Discovery.ServerByCountryCode(countryCode) - if err != nil { - c.goBackInternal() - c.logError(err) - return err - } - - if err = c.Servers.SetSecureLocation(srv); err != nil { - c.goBackInternal() - c.logError(err) - } - - return err -} - -// RemoveSecureInternet removes the current secure internet server. -// It returns an error if the server cannot be removed due to the state being DEREGISTERED. -// Note that if the server does not exist, it returns nil as an error. -func (c *Client) RemoveSecureInternet() error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("RemoveSecureInternet attempt in '%v'", StateDeregistered) - c.logError(err) - return err - } - // No error because we can only have one secure internet server and if there are no secure internet servers, this is a NO-OP - c.Servers.RemoveSecureInternet() - c.FSM.GoTransition(StateNoServer) - // Save the config - if err := c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - return nil -} - -// RemoveInstituteAccess removes the institute access server with `url`. -// It returns an error if the server cannot be removed due to the state being DEREGISTERED. -// Note that if the server does not exist, it returns nil as an error. -func (c *Client) RemoveInstituteAccess(url string) error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("RemoveInstituteAccess attempt in '%v'", StateDeregistered) - c.logError(err) - return err - } - // No error because this is a NO-OP if the server doesn't exist - c.Servers.RemoveInstituteAccess(url) - c.FSM.GoTransition(StateNoServer) - // Save the config - if err := c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - return nil -} - -// RemoveCustomServer removes the custom server with `url`. -// It returns an error if the server cannot be removed due to the state being DEREGISTERED. -// Note that if the server does not exist, it returns nil as an error. -func (c *Client) RemoveCustomServer(url string) error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("RemoveCustomServer attempt in '%v'", StateDeregistered) - c.logError(err) - return err - } - // No error because this is a NO-OP if the server doesn't exist - c.Servers.RemoveCustomServer(url) - c.FSM.GoTransition(StateNoServer) - // Save the config - if err := c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - return nil -} - -// AddInstituteServer adds an Institute Access server by `url`. -func (c *Client) AddInstituteServer(url string) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return errors.Errorf("adding and Institute Access server with Let's Connect is not supported") - } - - // Indicate that we're loading the server - c.FSM.GoTransition(StateLoadingServer) - - // Check if we are able to fetch discovery, and log if something went wrong - if _, err := c.DiscoServers(); err != nil { - log.Logger.Warningf("Failed to get discovery servers: %v", err) - } - - if _, err := c.DiscoOrganizations(); err != nil { - log.Logger.Warningf("Failed to get discovery organizations: %v", err) - } - - // FIXME: Do nothing with discovery here as the client already has it - // So pass a server as the parameter - var dSrv *discotypes.Server - dSrv, err = c.Discovery.ServerByURL(url, "institute_access") - if err != nil { - c.goBackInternal() - return err - } - - // Add the secure internet server - srv, err := c.Servers.AddInstituteAccessServer(dSrv) - if err != nil { - c.goBackInternal() - return err - } - - // Set the server as the current so OAuth can be cancelled - if err = c.Servers.SetInstituteAccess(srv); err != nil { - c.goBackInternal() - return err - } - - // Indicate that we want to authorize this server - c.FSM.GoTransition(StateChosenServer) - - // Authorize it - if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil { - // Removing is best effort - _ = c.RemoveInstituteAccess(url) - return err - } - - c.FSM.GoTransition(StateNoServer) - return nil -} - -// AddSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file. -// Because there is only one Secure Internet Home Server, it replaces the existing one. -func (c *Client) AddSecureInternetHomeServer(orgID string) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return errors.Errorf("adding a secure internet server with Let's Connect is not supported") - } - - // Indicate that we're loading the server - c.FSM.GoTransition(StateLoadingServer) - - // Check if we are able to fetch discovery, and log if something went wrong - if _, err := c.DiscoServers(); err != nil { - log.Logger.Warningf("Failed to get discovery servers: %v", err) - } - - if _, err := c.DiscoOrganizations(); err != nil { - log.Logger.Warningf("Failed to get discovery organizations: %v", err) - } - - // Get the secure internet URL from discovery - org, dSrv, err := c.Discovery.SecureHomeArgs(orgID) - if err != nil { - // We mark the organizations as expired because we got an error - // Note that in the docs it states that it only should happen when the Org ID doesn't exist - // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found - c.Discovery.MarkOrganizationsExpired() - c.goBackInternal() - return err - } - - // Add the secure internet server - srv, err := c.Servers.AddSecureInternet(org, dSrv) - if err != nil { - c.goBackInternal() - return err - } - - // TODO(jwijenbergh): Does this call transfers execution flow to UI? - if err = c.askSecureLocation(); err != nil { - // Removing is the best effort - // This already goes back to the main screen - _ = c.RemoveSecureInternet() - return err - } - - c.FSM.GoTransition(StateChosenLocation) - - // Set the server as the current so OAuth can be cancelled - if err = c.Servers.SetSecureInternet(srv); err != nil { - c.goBackInternal() - return err - } - - // Server has been chosen for authentication - c.FSM.GoTransition(StateChosenServer) - - // Authorize it - if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil { - // Removing is best effort - _ = c.RemoveSecureInternet() - return err - } - c.FSM.GoTransition(StateNoServer) - return nil -} - -// AddCustomServer adds a Custom Server by `url`. -func (c *Client) AddCustomServer(url string) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - if url, err = http.EnsureValidURL(url); err != nil { - return err - } - - // Indicate that we're loading the server - c.FSM.GoTransition(StateLoadingServer) - - customServer := &discotypes.Server{ - BaseURL: url, - DisplayName: map[string]string{"en": url}, - Type: "custom_server", - } - - // A custom server is just an institute access server under the hood - srv, err := c.Servers.AddCustomServer(customServer) - if err != nil { - c.goBackInternal() - return err - } - - // Set the server as the current so OAuth can be cancelled - if err = c.Servers.SetCustomServer(srv); err != nil { - c.goBackInternal() - return err - } - - // Server has been chosen for authentication - c.FSM.GoTransition(StateChosenServer) - - // Authorize it - if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil { - // removing is best effort - _ = c.RemoveCustomServer(url) - return err - } - - c.FSM.GoTransition(StateNoServer) - return nil -} - -// GetConfigInstituteAccess gets a configuration for an Institute Access Server. -// It ensures that the Institute Access Server exists by creating or using an existing one with the url. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return nil, errors.Errorf("discovery with Let's Connect is not supported") - } - - c.FSM.GoTransition(StateLoadingServer) - - // Get the server if it exists - var srv *server.InstituteAccessServer - if srv, err = c.Servers.GetInstituteAccess(url); err != nil { - c.goBackInternal() - return nil, err - } - - // Set the server as the current - if err = c.Servers.SetInstituteAccess(srv); err != nil { - return nil, err - } - - // The server has now been chosen - c.FSM.GoTransition(StateChosenServer) - - if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { - c.goBackInternal() - } - - // Also forward tokens using the callback - c.ForwardTokenUpdate(srv) - - return cfg, err -} - -// GetConfigSecureInternet gets a configuration for a Secure Internet Server. -// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -// TODO: Check on first argument orgID -func (c *Client) GetConfigSecureInternet(_ string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - log.Logger.Debugf("getting config for secure internet server with org ID: '%s", orgID) - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return nil, errors.Errorf("discovery with Let's Connect is not supported") - } - - c.FSM.GoTransition(StateLoadingServer) - - // Get the server if it exists - var srv *server.SecureInternetHomeServer - if srv, err = c.Servers.GetSecureInternetHomeServer(); err != nil { - c.goBackInternal() - return nil, err - } - - // Set the server as the current - if err = c.Servers.SetSecureInternet(srv); err != nil { - return nil, err - } - - c.FSM.GoTransition(StateChosenServer) - - if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { - c.goBackInternal() - } - - // Also forward tokens using the callback - c.ForwardTokenUpdate(srv) - - return cfg, err -} - -// GetConfigCustomServer gets a configuration for a Custom Server. -// It ensures that the Custom Server exists by creating or using an existing one with the url. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - if url, err = http.EnsureValidURL(url); err != nil { - return nil, err - } - - c.FSM.GoTransition(StateLoadingServer) - - // Get the server if it exists - var srv *server.InstituteAccessServer - if srv, err = c.Servers.GetCustomServer(url); err != nil { - c.goBackInternal() - return nil, err - } - - // Set the server as the current - if err = c.Servers.SetCustomServer(srv); err != nil { - c.goBackInternal() - return nil, err - } - - c.FSM.GoTransition(StateChosenServer) - - if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { - c.goBackInternal() - } - - // Also forward tokens using the callback - c.ForwardTokenUpdate(srv) - - return cfg, err -} - -// askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state. -func (c *Client) askSecureLocation() error { - loc := c.Discovery.SecureLocationList() - - c.locationWg.Add(1) - // Ask for the location in the callback - if err := c.FSM.GoTransitionRequired(StateAskLocation, loc); err != nil { - return err - } - - c.locationWg.Wait() - - // The state has changed, meaning setting the secure location was not successful - if c.FSM.Current != StateAskLocation { - log.Logger.Debugf("fsm failed to transit; expected %v / actual %v", GetStateName(StateAskLocation), GetStateName(c.FSM.Current)) - return errors.New("failed loading secure internet location") - } - return nil -} - -// RenewSession renews the session for the current VPN server. -// This logs the user back in. -func (c *Client) RenewSession() (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - var srv server.Server - if srv, err = c.Servers.GetCurrentServer(); err != nil { - return err - } - - err = srv.RefreshEndpoints(&c.Discovery) - if err != nil { - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - // The server has not been chosen yet, this means that we want to manually renew - if !c.FSM.InState(StateChosenServer) { - c.FSM.GoTransition(StateLoadingServer) - c.FSM.GoTransition(StateChosenServer) - } - - server.MarkTokensForRenew(srv) - return c.ensureLogin(srv, srvtypes.Tokens{}) -} - -// ensureLogin logs the user back in if needed. -// It runs the FSM transitions to ask for user input. -func (c *Client) ensureLogin(srv server.Server, t srvtypes.Tokens) (err error) { - // Relogin with oauth - // This moves the state to authorized - if !server.NeedsRelogin(srv) { - // OAuth was valid, ensure we are in the authorized state - c.FSM.GoTransition(StateAuthorized) - return nil - } - - // Try again but update the tokens using the client provided tokens - server.UpdateTokens(srv, oauth.Token{ - Access: t.Access, - Refresh: t.Refresh, - ExpiredTimestamp: time.Unix(t.Expires, 0), - }) - if !server.NeedsRelogin(srv) { - // OAuth was valid, ensure we are in the authorized state - c.FSM.GoTransition(StateAuthorized) - return nil - } - - // Mark organizations as expired if the server is a secure internet server - b, err := srv.Base() - // We only try to update it when we found the server base - if err == nil && b.Type == "secure_internet" { - c.Discovery.MarkOrganizationsExpired() - } - - // Tokens are not valid or the client gave an error when updating tokens - // Otherwise, do the OAuth exchange - var url string - if url, err = server.OAuthURL(srv, c.Name); err != nil { - return err - } - - if err = c.FSM.GoTransitionRequired(StateOAuthStarted, url); err != nil { - return err - } - - if err = server.OAuthExchange(srv); err != nil { - c.goBackInternal() - } - c.FSM.GoTransition(StateAuthorized) - - return err -} - -// SetProfileID sets a `profileID` for the current server. -// An error is returned if this is not possible, for example when no server is configured. -func (c *Client) SetProfileID(profileID string) (err error) { - if c.InFSMState(StateAskProfile) { - defer c.profileWg.Done() - } - defer func() { - if err != nil { - c.logError(err) - } - }() - - var srv server.Server - if srv, err = c.Servers.GetCurrentServer(); err != nil { - c.goBackInternal() - return err - } - - var b *server.Base - if b, err = srv.Base(); err != nil { - c.goBackInternal() - return err - } - b.Profiles.Current = profileID - - return nil -} - -func (c *Client) StartFailover(gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { - if c.Failover != nil { - return false, errors.New("another failover process is already started") - } - c.Failover = failover.New(readRxBytes) - - return c.Failover.Start(gateway, wgMTU) -} - -func (c *Client) CancelFailover() error { - if c.Failover == nil { - return errors.New("no failover process") - } - c.Failover.Cancel() - return nil -} |
