diff options
| -rw-r--r-- | client/client.go | 622 | ||||
| -rw-r--r-- | client/client_test.go | 84 | ||||
| -rw-r--r-- | client/fsm.go | 48 | ||||
| -rw-r--r-- | client/server.go | 701 |
4 files changed, 513 insertions, 942 deletions
diff --git a/client/client.go b/client/client.go index 70adb71..813f6dc 100644 --- a/client/client.go +++ b/client/client.go @@ -2,9 +2,9 @@ package client import ( + "context" "fmt" "strings" - "sync" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -14,32 +14,12 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" + "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type ( - // ServerBase is an alias to the internal ServerBase - // This contains the details for each server. - ServerBase = server.Base -) - -func (c *Client) logError(err error) { - // Logs the error with the same level/verbosity as the error - if c.Debug { - log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) - } else { - log.Logger.Inherit(err, "") - } -} - -func (c *Client) isLetsConnect() bool { - // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php - return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") -} - // isAllowedClientID checks if the 'clientID' is in the list of allowed client IDs func isAllowedClientID(clientID string) bool { allowList := []string{ @@ -91,13 +71,27 @@ func userAgentName(clientID string) string { } } +func (c *Client) logError(err error) { + // Logs the error with the same level/verbosity as the error + if c.Debug { + log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) + } else { + log.Logger.Inherit(err, "") + } +} + +func (c *Client) isLetsConnect() bool { + // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php + return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") +} + // Client is the main struct for the VPN client. type Client struct { // The name of the client Name string `json:"-"` // The chosen server - Servers server.Servers `json:"servers"` + Servers server.List `json:"servers"` // The list of servers and organizations from disco Discovery discovery.Discovery `json:"discovery"` @@ -116,9 +110,6 @@ type Client struct { // The Failover monitor for the current VPN connection Failover *failover.DroppedConMon - - locationWg sync.WaitGroup - profileWg sync.WaitGroup } // New creates a new client with the following parameters: @@ -179,7 +170,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Registering means updating the FSM to get to the initial state correctly func (c *Client) Register() error { - if !c.InFSMState(StateDeregistered) { + if !c.FSM.InState(StateDeregistered) { return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) } c.FSM.GoTransition(StateNoServer) @@ -200,27 +191,11 @@ func (c *Client) Deregister() { *c = Client{} } -// askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. -func (c *Client) askProfile(srv server.Server) error { - ps, err := server.ValidProfiles(srv, c.SupportsWireguard) - if err != nil { - return err - } - - c.profileWg.Add(1) - if err = c.FSM.GoTransitionRequired(StateAskProfile, convertProfiles(*ps)); err != nil { - return err - } - c.profileWg.Wait() - - return nil -} - // DiscoOrganizations gets the organizations list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list. -func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error) { +func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) { defer func() { if err != nil { c.logError(err) @@ -237,14 +212,15 @@ func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error c.Discovery.MarkOrganizationsExpired() } - return c.Discovery.Organizations() + // TODO: pass a context + return c.Discovery.Organizations(ck.Context()) } // DiscoServers gets the servers list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. -func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { +func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) { defer func() { if err != nil { c.logError(err) @@ -256,7 +232,8 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { return nil, errors.Errorf("discovery with Let's Connect is not supported") } - return c.Discovery.Servers() + // TODO: pass a context + return c.Discovery.Servers(ck.Context()) } // ExpiryTimes returns the different Unix timestamps regarding expiry @@ -266,7 +243,7 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { // These times are reset when the VPN gets disconnected func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { // Get current expiry time - srv, err := c.Servers.GetCurrentServer() + srv, err := c.Servers.Current() if err != nil { c.logError(err) return nil, err @@ -293,149 +270,488 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func convertProfiles(profiles server.ProfileInfo) srvtypes.Profiles { - m := make(map[string]srvtypes.Profile) - for _, p := range profiles.Info.ProfileList { - var protocols []protocol.Protocol - // loop through all protocol strings - for _, ps := range p.VPNProtoList { - protocols = append(protocols, protocol.New(ps)) - } - m[p.ID] = srvtypes.Profile{ - DisplayName: map[string]string{ - "en": p.DisplayName, - }, - Protocols: protocols, +func (c *Client) locationCallback(ck *cookie.Cookie) error { + locs := c.Discovery.SecureLocationList() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ + C: ck, + Data: locs, + }) + if err != nil { + errChan <- err } + }() + loc, err := ck.Receive(errChan) + if err != nil { + return err } - return srvtypes.Profiles{Map: m, Current: profiles.Current} + err = c.SetSecureLocation(ck, loc) + if err != nil { + return err + } + t := c.FSM.GoTransition(StateChosenLocation) + if !t { + log.Logger.Warningf("transition chosen location not completed") + } + return nil } -func convertGeneric(server server.InstituteAccessServer) (*srvtypes.Server, error) { - b, err := server.Base() +func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error { + url, err := server.OAuthURL(srv, c.Name) if err != nil { - return nil, err + return err } - return &srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - }, nil + err = c.FSM.GoTransitionRequired(StateOAuthStarted, url) + if err != nil { + return err + } + err = server.OAuthExchange(ck.Context(), srv) + if err != nil { + return err + } + return nil } -// TODO: CLEAN THIS UP -func (c *Client) ServerList() (*srvtypes.List, error) { - custom := c.Servers.CustomServers - var customServers []srvtypes.Server - for _, v := range custom.Map { - if v == nil { - return nil, errors.New("found nil value in custom server map") +func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) error { + // location + if srv.NeedsLocation() { + err := c.locationCallback(ck) + if err != nil { + return err } - conv, err := convertGeneric(*v) + } + + t := c.FSM.GoTransition(StateChosenServer) + if !t { + log.Logger.Warningf("transition not completed for chosen server") + } + // oauth + // TODO: This should be ck.Context() + // But needsrelogin needs a rewrite to support this properly + if server.NeedsRelogin(context.Background(), srv) || forceauth { + err := c.loginCallback(ck, srv) if err != nil { - return nil, errors.Errorf("failed to convert custom server for public type: %v", err) + return err } - customServers = append(customServers, *conv) } - institute := c.Servers.InstituteServers - var instituteServers []srvtypes.Institute - for _, v := range institute.Map { - if v == nil { - return nil, errors.New("found nil value in institute server map") + t = c.FSM.GoTransition(StateAuthorized) + if !t { + log.Logger.Warningf("transition authorized not completed") + } + + return nil +} + +func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error { + vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard) + if err != nil { + return err + } + if !vp { + b, err := srv.Base() + if err != nil { + return err } - conv, err := convertGeneric(*v) + ps := b.Profiles.Public() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ + C: ck, + Data: ps, + }) + if err != nil { + errChan <- err + } + }() + pID, err := ck.Receive(errChan) if err != nil { - return nil, errors.Errorf("failed to convert institute server for public type: %v", err) + return err + } + err = server.Profile(srv, pID) + if err != nil { + return err } - instituteServers = append(instituteServers, srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }) } + t := c.FSM.GoTransition(StateChosenProfile) + if !t { + log.Logger.Warningf("transition chosen profile not completed") + } + return nil +} - var secureInternet *srvtypes.SecureInternet - if c.Servers.HasSecureInternet() { - b, err := c.Servers.SecureInternetHomeServer.Base() - if err == nil { - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - secureInternet = &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, - // TODO: delisted - Delisted: false, - } +// AddServer adds a server with identifier and type +func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) { + // If we have failed to add the server, we remove it again + // We add the server because we can then obtain it in other callback functions + defer func() { + if err != nil { + _ = c.RemoveServer(identifier, _type) //nolint:errcheck } + c.FSM.GoTransition(StateNoServer) + }() + if !ni { + if !c.FSM.InState(StateNoServer) { + return errors.Errorf("wrong state to add a server: %s", GetStateName(c.FSM.Current)) + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } } - return &srvtypes.List{ - Institutes: instituteServers, - SecureInternet: secureInternet, - Custom: customServers, - }, nil -} -// TODO: CLEAN THIS UP -func (c *Client) CurrentServer() (*srvtypes.Current, error) { - srvs := c.Servers + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + + var srv server.Server - switch srvs.IsType { - case server.InstituteAccessServerType: - curr, err := srvs.GetInstituteAccess(srvs.InstituteServers.CurrentURL) + switch _type { + case srvtypes.TypeInstituteAccess: + dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access") if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddInstituteAccess(ck.Context(), dSrv) if err != nil { - return nil, err + return err } - return &srvtypes.Current{ - Institute: &srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }, - Type: srvtypes.TypeInstituteAccess, - }, nil - case server.CustomServerType: - curr, err := srvs.GetCustomServer(srvs.CustomServers.CurrentURL) + case srvtypes.TypeSecureInternet: + dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) if err != nil { - return nil, err + return err + } + case srvtypes.TypeCustom: + srv, err = c.Servers.AddCustom(ck.Context(), identifier) + if err != nil { + return err + } + default: + return errors.Errorf("not a valid server type: %v", _type) + } + + // if we are non interactive, we run no callbacks + if ni { + return nil + } + + // callbacks + return c.callbacks(ck, srv, false) +} + +func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { + // do the callbacks to ensure valid profile, location and authorization + err = c.callbacks(ck, srv, forceAuth) + if err != nil { + return nil, err + } + + t := c.FSM.GoTransition(StateRequestConfig) + if !t { + log.Logger.Warningf("transition not completed for requesting config") + } + + err = c.profileCallback(ck, srv) + if err != nil { + return nil, err + } + + cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP) + if err != nil { + return nil, err + } + p, err := server.CurrentProfile(srv) + if err != nil { + return nil, err + } + pcfg := cfgS.Public(p.DefaultGateway) + if err != nil { + return nil, err + } + return &pcfg, nil +} + +func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) { + switch _type { + case srvtypes.TypeInstituteAccess: + srv, err = c.Servers.InstituteAccess(identifier) + setter = c.Servers.SetInstituteAccess + case srvtypes.TypeSecureInternet: + srv, err = c.Servers.SecureInternet(identifier) + setter = c.Servers.SetSecureInternet + case srvtypes.TypeCustom: + srv, err = c.Servers.CustomServer(identifier) + setter = c.Servers.SetCustom + default: + return nil, nil, errors.Errorf("not a valid server type: %v", _type) + } + return srv, setter, err +} + +// GetConfig gets a VPN configuration +func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) { + defer func() { + if err == nil { + c.FSM.GoTransition(StateGotConfig) + } else { + // go back if an error occurred + c.FSM.GoTransition(StateNoServer) + } + }() + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return nil, err + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } + srv, set, err := c.server(identifier, _type) + if err != nil { + return nil, err + } + // refresh the server endpoints + err = server.RefreshEndpoints(ck.Context(), srv) + if err != nil { + log.Logger.Warningf("failed to refresh server endpoints: %v", err) + } + + // get a config and retry with authorization if expired + cfg, err = c.config(ck, srv, pTCP, false) + tErr := &oauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + cfg, err = c.config(ck, srv, pTCP, true) + } + + // still an error, return nil with the error + if err != nil { + return nil, err + } + + // set the current server + if err = set(srv); err != nil { + return nil, err + } + + return cfg, nil +} + +func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + switch _type { + case srvtypes.TypeInstituteAccess: + return c.Servers.RemoveInstituteAccess(identifier) + case srvtypes.TypeSecureInternet: + return c.Servers.RemoveSecureInternet(identifier) + case srvtypes.TypeCustom: + return c.Servers.RemoveCustom(identifier) + default: + return errors.Errorf("not a valid server type: %v", _type) + } +} + +func (c *Client) CurrentServer() (*srvtypes.Current, error) { + if !c.FSM.InState(StateGotConfig) { + return nil, errors.Errorf("State: %s, cannot have a current server. Did you get a VPN configuration?", GetStateName(c.FSM.Current)) + } + srv, err := c.Servers.Current() + if err != nil { + return nil, err + } + return c.pubCurrentServer(srv) +} + +func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) { + b, err := srv.Base() + if err != nil { + return nil, err + } + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Current{ + Institute: &srvtypes.Institute{ + Server: *t, + // TODO: delisted + Delisted: false, + }, + Type: srvtypes.TypeInstituteAccess, + }, nil } return &srvtypes.Current{ - Custom: conv, + Custom: t, Type: srvtypes.TypeCustom, }, nil - case server.SecureInternetServerType: - b, err := c.Servers.SecureInternetHomeServer.Base() + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return &srvtypes.Current{ + SecureInternet: t, + Type: srvtypes.TypeSecureInternet, + }, nil + default: + panic("unknown type") + } +} + +// TODO: This should not rely on interface{} +func (c *Client) pubServer(srv server.Server) (interface{}, error) { + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + b, err := srv.Base() if err != nil { return nil, err } - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: c.Servers.SecureInternetHomeServer.HomeOrganizationID, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - return &srvtypes.Current{ - SecureInternet: &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Institute{ + Server: *t, // TODO: delisted Delisted: false, - }, - Type: srvtypes.TypeSecureInternet, - }, nil + }, nil + } + return t, nil + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return t, nil default: - return nil, errors.New("current server not found") + panic("unknown type") + } +} + +func (c *Client) ServerList() (*srvtypes.List, error) { + if c.FSM.InState(StateDeregistered) { + return nil, errors.New("client is not registered") + } + var customServers []srvtypes.Server + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + c, ok := p.(*srvtypes.Server) + if !ok { + continue + } + customServers = append(customServers, *c) } + var instituteServers []srvtypes.Institute + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + i, ok := p.(*srvtypes.Institute) + if !ok { + continue + } + instituteServers = append(instituteServers, *i) + } + var secureInternet *srvtypes.SecureInternet + if c.Servers.HasSecureInternet() { + srv := &c.Servers.SecureInternetHomeServer + p, err := c.pubServer(srv) + if err == nil { + s, ok := p.(*srvtypes.SecureInternet) + if ok { + secureInternet = s + } + } + } + return &srvtypes.List{ + Institutes: instituteServers, + SecureInternet: secureInternet, + Custom: customServers, + }, nil +} + +func (c *Client) SetProfileID(pID string) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + return server.Profile(srv, pID) +} + +func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { + // get the current server + srv, err := c.Servers.Current() + if err != nil { + return err + } + // TODO: Support cookie context here + // if server.NeedsRelogin(context.Background(), srv) { + // // TODO: ask client for tokens + // } + err = server.Disconnect(ck.Context(), srv) + if err != nil { + return err + } + // TODO: Set tokens with callback + return nil +} + +func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) { + if c.isLetsConnect() { + return errors.Errorf("setting a secure internet location with Let's Connect! is not supported") + } + + if !c.Servers.HasSecureInternet() { + return errors.Errorf("no secure internet server available to set a location for") + } + + dSrv, err := c.Discovery.ServerByCountryCode(countryCode) + if err != nil { + return err + } + + return c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv) +} + +func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + // The server has not been chosen yet, this means that we want to manually renew + // TODO: is this needed? + if !c.FSM.InState(StateChosenServer) { + c.FSM.GoTransition(StateLoadingServer) + c.FSM.GoTransition(StateChosenServer) + } + // TODO: Maybe this can be deleted because we force auth now + server.MarkTokensForRenew(srv) + // run the callbacks by forcing auth + return c.callbacks(ck, srv, true) +} + +func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { + if c.Failover != nil { + return false, errors.New("another failover process is already started") + } + + c.Failover = failover.New(readRxBytes) + + return c.Failover.Start(ck.Context(), gateway, wgMTU) } diff --git a/client/client_test.go b/client/client_test.go index 56c38ff..7077ce4 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "net/http" "net/url" @@ -12,6 +13,7 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/types/cookie" "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" @@ -22,7 +24,7 @@ func getServerURI(t *testing.T) string { if serverURI == "" { t.Skip("Skipping server test as no SERVER_URI env var has been passed") } - serverURI, parseErr := httpw.EnsureValidURL(serverURI) + serverURI, parseErr := httpw.EnsureValidURL(serverURI, true) if parseErr != nil { t.Skip("Skipping server test as the server uri is not valid") } @@ -41,13 +43,13 @@ func runCommand(errBuffer *strings.Builder, name string, args ...string) error { return cmd.Wait() } -func loginOAuthSelenium(url string, state *Client) { +func loginOAuthSelenium(ck *cookie.Cookie, url string) { // We could use the go selenium library // But it does not support the latest selenium v4 just yet var errBuffer strings.Builder err := runCommand(&errBuffer, "python3", "../selenium_eduvpn.py", url) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() panic(fmt.Sprintf( "Login OAuth with selenium script failed with error %v and stderr %s", err, @@ -58,10 +60,10 @@ func loginOAuthSelenium(url string, state *Client) { func stateCallback( t *testing.T, + ck *cookie.Cookie, _ FSMStateID, newState FSMStateID, data interface{}, - state *Client, ) { if newState == StateOAuthStarted { url, ok := data.(string) @@ -69,20 +71,20 @@ func stateCallback( if !ok { t.Fatalf("data is not a string for OAuth URL") } - loginOAuthSelenium(url, state) + loginOAuthSelenium(ck, url) } } func TestServer(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configstest", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -95,12 +97,11 @@ func TestServer(t *testing.T) { t.Fatalf("Registering error: %v", err) } - - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error: %v", configErr) } @@ -112,33 +113,36 @@ func testConnectOAuthParameter( errPrefix string, ) { serverURI := getServerURI(t) - state := &Client{} configDirectory := "test_oauth_parameters" + state := &Client{} + + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", configDirectory, func(oldState FSMStateID, newState FSMStateID, data interface{}) bool { if newState == StateOAuthStarted { - server, serverErr := state.Servers.GetCustomServer(serverURI) + server, serverErr := state.Servers.CustomServer(serverURI) if serverErr != nil { t.Fatalf("No server with error: %v", serverErr) } port, portErr := server.OAuth().ListenerPort() if portErr != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf("No port with error: %v", portErr) } baseURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) p, err := url.Parse(baseURL) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf("Failed to parse URL with error: %v", err) } url, err := httpw.ConstructURL(p, parameters) if err != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Fatalf( "Error: Constructing url %s with parameters %s", baseURL, @@ -148,7 +152,7 @@ func testConnectOAuthParameter( go func() { _, getErr := http.Get(url) if getErr != nil { - _ = state.CancelOAuth() + _ = ck.Cancel() t.Logf("HTTP GET error: %v", getErr) } }() @@ -165,7 +169,7 @@ func testConnectOAuthParameter( t.Fatalf("Registering error: %v", err) } - err = state.AddCustomServer(serverURI) + err = state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if errPrefix == "" { if err != nil { @@ -247,14 +251,14 @@ func TestTokenExpired(t *testing.T) { } // Get a vpn state - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsexpired", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -267,25 +271,25 @@ func TestTokenExpired(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error before expired: %v", configErr) } - currentServer, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.Current() if serverErr != nil { t.Fatalf("No server found") } serverOAuth := currentServer.OAuth() - accessToken, accessTokenErr := serverOAuth.AccessToken() + accessToken, accessTokenErr := serverOAuth.AccessToken(ck.Context()) if accessTokenErr != nil { t.Fatalf("Failed to get token: %v", accessTokenErr) } @@ -293,14 +297,14 @@ func TestTokenExpired(t *testing.T) { // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Connect error after expiry: %v", configErr) } // Check if tokens have changed - accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken() + accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken(ck.Context()) if accessTokenAfterErr != nil { t.Fatalf("Failed to get token: %v", accessTokenAfterErr) } @@ -313,14 +317,14 @@ func TestTokenExpired(t *testing.T) { // Test if an invalid profile will be corrected. func TestInvalidProfileCorrected(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configscancelprofile", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -333,18 +337,18 @@ func TestInvalidProfileCorrected(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } - _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("First connect error: %v", configErr) } - currentServer, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.Current() if serverErr != nil { t.Fatalf("No server found") } @@ -357,7 +361,7 @@ func TestInvalidProfileCorrected(t *testing.T) { previousProfile := base.Profiles.Current base.Profiles.Current = "IDONOTEXIST" - _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Second connect error: %v", configErr) } @@ -374,14 +378,14 @@ func TestInvalidProfileCorrected(t *testing.T) { // Test if prefer tcp is handled correctly by checking the returned config and config type. func TestPreferTCP(t *testing.T) { serverURI := getServerURI(t) - state := &Client{} - + ck := cookie.NewWithContext(context.Background()) + defer ck.Cancel() //nolint:errcheck state, err := New( "org.letsconnect-vpn.app.linux", "0.1.0-test", "configsprefertcp", func(old FSMStateID, new FSMStateID, data interface{}) bool { - stateCallback(t, old, new, data, state) + stateCallback(t, &ck, old, new, data) return true }, false, @@ -394,13 +398,13 @@ func TestPreferTCP(t *testing.T) { t.Fatalf("Registering error: %v", err) } - addErr := state.AddCustomServer(serverURI) + addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false) if addErr != nil { t.Fatalf("Add error: %v", addErr) } // get a config with preferTCP set to true - config, configErr := state.GetConfigCustomServer(serverURI, true, srvtypes.Tokens{}) + config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true) // Test server should accept prefer TCP! if config.Protocol != protocol.OpenVPN { @@ -417,7 +421,7 @@ func TestPreferTCP(t *testing.T) { } // get a config with preferTCP set to false - config, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{}) + config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false) if configErr != nil { t.Fatalf("Config error: %v", configErr) } diff --git a/client/fsm.go b/client/fsm.go index 9f140e3..038e6cf 100644 --- a/client/fsm.go +++ b/client/fsm.go @@ -2,9 +2,6 @@ package client import ( "github.com/eduvpn/eduvpn-common/internal/fsm" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/server" - "github.com/go-errors/errors" ) type ( @@ -94,7 +91,6 @@ func newFSM( }, StateNoServer: FSMState{ Transitions: []FSMTransition{ - {To: StateNoServer, Description: "Reload list"}, {To: StateLoadingServer, Description: "User clicks a server in the UI"}, }, }, @@ -170,47 +166,3 @@ func newFSM( returnedFSM.Init(StateDeregistered, states, callback, directory, GetStateName, debug) return returnedFSM } - -// GoBack transitions the FSM back to the previous UI state, for now this is always the NO_SERVER state. -func (c *Client) GoBack() error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("fsm attempt going back from 'StateDeregistered'") - c.logError(err) - return err - } - - // FIXME: Arbitrary back transitions don't work because we need the appropriate data - c.FSM.GoTransition(StateNoServer) - return nil -} - -// goBackInternal uses the public go back but logs an error if it happened. -func (c *Client) goBackInternal() { - err := c.GoBack() - if err != nil { - // TODO(jwijenbergh): Bit suspicious - logging level INFO, yet stacktrace logged. - log.Logger.Infof("failed going back: %s\nstacktrace:\n%s", err.Error(), err.(*errors.Error).ErrorStack()) - } -} - -// CancelOAuth cancels OAuth if one is in progress. -// If OAuth is not in progress, it returns an error. -// An error is also returned if OAuth is in progress, but it fails to cancel it. -func (c *Client) CancelOAuth() error { - if !c.InFSMState(StateOAuthStarted) { - return errors.Errorf("fsm attempt cancelling OAuth while in '%v'", c.FSM.Current) - } - - srv, err := c.Servers.GetCurrentServer() - if err != nil { - c.logError(err) - } else { - server.CancelOAuth(srv) - } - return err -} - -// InFSMState is a helper to check if the FSM is in state `checkState`. -func (c *Client) InFSMState(checkState FSMStateID) bool { - return c.FSM.InState(checkState) -} diff --git a/client/server.go b/client/server.go deleted file mode 100644 index b3f7747..0000000 --- a/client/server.go +++ /dev/null @@ -1,701 +0,0 @@ -package client - -import ( - "time" - - "github.com/eduvpn/eduvpn-common/internal/failover" - "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/server" - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/eduvpn/eduvpn-common/types/protocol" - srvtypes "github.com/eduvpn/eduvpn-common/types/server" - "github.com/go-errors/errors" -) - -// TODO: This should not be reliant on an internal type -func getTokens(tok oauth.Token) srvtypes.Tokens { - return srvtypes.Tokens{ - Access: tok.Access, - Refresh: tok.Refresh, - Expires: tok.ExpiredTimestamp.Unix(), - } -} - -// getConfigAuth gets a config with authorization and authentication. -// It also asks for a profile if no valid profile is found. -func (c *Client) getConfigAuth(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) { - err := c.ensureLogin(srv, t) - if err != nil { - return nil, err - } - - // TODO(jwijenbergh): Should we check if it returns false? - c.FSM.GoTransition(StateRequestConfig) - - ok, err := server.HasValidProfile(srv, c.SupportsWireguard) - if err != nil { - return nil, err - } - - // No valid profile, ask for one - if !ok { - if err = c.askProfile(srv); err != nil { - return nil, err - } - } - - // The profile has been chosen - c.FSM.GoTransition(StateChosenProfile) - - cfg, err := server.Config(srv, c.SupportsWireguard, preferTCP) - if err != nil { - return nil, err - } - - p, err := server.CurrentProfile(srv) - if err != nil { - return nil, err - } - - pCfg := &srvtypes.Configuration{ - VPNConfig: cfg.Config, - Protocol: protocol.New(cfg.Type), - DefaultGateway: p.DefaultGateway, - Tokens: getTokens(cfg.Tokens), - } - - return pCfg, nil -} - -// retryConfigAuth retries the getConfigAuth function if the tokens are invalid. -// If OAuth is cancelled, it makes sure that we only forward the error as additional info. -func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) { - cfg, err := c.getConfigAuth(srv, preferTCP, t) - if err == nil { - return cfg, nil - } - // Only retry if the error is that the tokens are invalid - tErr := &oauth.TokensInvalidError{} - if errors.As(err, &tErr) { - // TODO: Is passing empty tokens correct here? - cfg, err = c.getConfigAuth(srv, preferTCP, srvtypes.Tokens{}) - if err == nil { - return cfg, nil - } - } - c.goBackInternal() - return nil, err -} - -// getConfig gets an OpenVPN/WireGuard configuration by contacting the server, moving the FSM towards the DISCONNECTED state and then saving the local configuration file. -func (c *Client) getConfig(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) { - if c.InFSMState(StateDeregistered) { - return nil, errors.Errorf("getConfig attempt in '%v'", StateDeregistered) - } - - // Refresh the server endpoints - // This is the best effort - err := srv.RefreshEndpoints(&c.Discovery) - if err != nil { - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - cfg, err := c.retryConfigAuth(srv, preferTCP, t) - if err != nil { - return nil, err - } - - // Save the config - if err = c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - - c.FSM.GoTransition(StateGotConfig) - - return cfg, nil -} - -// Cleanup cleans up the VPN connection by sending a /disconnect to the server -func (c *Client) Cleanup(ct srvtypes.Tokens) error { - srv, err := c.Servers.GetCurrentServer() - if err != nil { - c.logError(err) - return err - } - err = srv.RefreshEndpoints(&c.Discovery) - if err != nil { - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - // If we need to relogin, update tokens - if server.NeedsRelogin(srv) { - server.UpdateTokens(srv, oauth.Token{ - Access: ct.Access, - Refresh: ct.Refresh, - ExpiredTimestamp: time.Unix(ct.Expires, 0), - }) - } - // update tokens to client - defer c.ForwardTokenUpdate(srv) - // Do the /disconnect API call - err = server.Disconnect(srv) - if err != nil { - // We log nothing here because this can happen regularly - // Maybe we should not log errors that we return directly anyways? - return err - } - // TODO: Tokens might be refreshed, return updated tokens - // Not implemented yet, because ideally we want this implemented with an interface - return nil -} - -// SetSecureLocation sets the location for the current secure location server. countryCode is the secure location to be chosen. -// This function returns an error e.g. if the server cannot be found or the location is wrong. -func (c *Client) SetSecureLocation(countryCode string) error { - if c.InFSMState(StateAskLocation) { - defer c.locationWg.Done() - } - // Not supported with Let's Connect! - if c.isLetsConnect() { - err := errors.Errorf("discovery with Let's Connect is not supported") - c.logError(err) - return err - } - - srv, err := c.Discovery.ServerByCountryCode(countryCode) - if err != nil { - c.goBackInternal() - c.logError(err) - return err - } - - if err = c.Servers.SetSecureLocation(srv); err != nil { - c.goBackInternal() - c.logError(err) - } - - return err -} - -// RemoveSecureInternet removes the current secure internet server. -// It returns an error if the server cannot be removed due to the state being DEREGISTERED. -// Note that if the server does not exist, it returns nil as an error. -func (c *Client) RemoveSecureInternet() error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("RemoveSecureInternet attempt in '%v'", StateDeregistered) - c.logError(err) - return err - } - // No error because we can only have one secure internet server and if there are no secure internet servers, this is a NO-OP - c.Servers.RemoveSecureInternet() - c.FSM.GoTransition(StateNoServer) - // Save the config - if err := c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - return nil -} - -// RemoveInstituteAccess removes the institute access server with `url`. -// It returns an error if the server cannot be removed due to the state being DEREGISTERED. -// Note that if the server does not exist, it returns nil as an error. -func (c *Client) RemoveInstituteAccess(url string) error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("RemoveInstituteAccess attempt in '%v'", StateDeregistered) - c.logError(err) - return err - } - // No error because this is a NO-OP if the server doesn't exist - c.Servers.RemoveInstituteAccess(url) - c.FSM.GoTransition(StateNoServer) - // Save the config - if err := c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - return nil -} - -// RemoveCustomServer removes the custom server with `url`. -// It returns an error if the server cannot be removed due to the state being DEREGISTERED. -// Note that if the server does not exist, it returns nil as an error. -func (c *Client) RemoveCustomServer(url string) error { - if c.InFSMState(StateDeregistered) { - err := errors.Errorf("RemoveCustomServer attempt in '%v'", StateDeregistered) - c.logError(err) - return err - } - // No error because this is a NO-OP if the server doesn't exist - c.Servers.RemoveCustomServer(url) - c.FSM.GoTransition(StateNoServer) - // Save the config - if err := c.Config.Save(&c); err != nil { - // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace... - // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well. - log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s", - err.Error(), err.(*errors.Error).ErrorStack()) - } - return nil -} - -// AddInstituteServer adds an Institute Access server by `url`. -func (c *Client) AddInstituteServer(url string) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return errors.Errorf("adding and Institute Access server with Let's Connect is not supported") - } - - // Indicate that we're loading the server - c.FSM.GoTransition(StateLoadingServer) - - // Check if we are able to fetch discovery, and log if something went wrong - if _, err := c.DiscoServers(); err != nil { - log.Logger.Warningf("Failed to get discovery servers: %v", err) - } - - if _, err := c.DiscoOrganizations(); err != nil { - log.Logger.Warningf("Failed to get discovery organizations: %v", err) - } - - // FIXME: Do nothing with discovery here as the client already has it - // So pass a server as the parameter - var dSrv *discotypes.Server - dSrv, err = c.Discovery.ServerByURL(url, "institute_access") - if err != nil { - c.goBackInternal() - return err - } - - // Add the secure internet server - srv, err := c.Servers.AddInstituteAccessServer(dSrv) - if err != nil { - c.goBackInternal() - return err - } - - // Set the server as the current so OAuth can be cancelled - if err = c.Servers.SetInstituteAccess(srv); err != nil { - c.goBackInternal() - return err - } - - // Indicate that we want to authorize this server - c.FSM.GoTransition(StateChosenServer) - - // Authorize it - if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil { - // Removing is best effort - _ = c.RemoveInstituteAccess(url) - return err - } - - c.FSM.GoTransition(StateNoServer) - return nil -} - -// AddSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file. -// Because there is only one Secure Internet Home Server, it replaces the existing one. -func (c *Client) AddSecureInternetHomeServer(orgID string) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return errors.Errorf("adding a secure internet server with Let's Connect is not supported") - } - - // Indicate that we're loading the server - c.FSM.GoTransition(StateLoadingServer) - - // Check if we are able to fetch discovery, and log if something went wrong - if _, err := c.DiscoServers(); err != nil { - log.Logger.Warningf("Failed to get discovery servers: %v", err) - } - - if _, err := c.DiscoOrganizations(); err != nil { - log.Logger.Warningf("Failed to get discovery organizations: %v", err) - } - - // Get the secure internet URL from discovery - org, dSrv, err := c.Discovery.SecureHomeArgs(orgID) - if err != nil { - // We mark the organizations as expired because we got an error - // Note that in the docs it states that it only should happen when the Org ID doesn't exist - // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found - c.Discovery.MarkOrganizationsExpired() - c.goBackInternal() - return err - } - - // Add the secure internet server - srv, err := c.Servers.AddSecureInternet(org, dSrv) - if err != nil { - c.goBackInternal() - return err - } - - // TODO(jwijenbergh): Does this call transfers execution flow to UI? - if err = c.askSecureLocation(); err != nil { - // Removing is the best effort - // This already goes back to the main screen - _ = c.RemoveSecureInternet() - return err - } - - c.FSM.GoTransition(StateChosenLocation) - - // Set the server as the current so OAuth can be cancelled - if err = c.Servers.SetSecureInternet(srv); err != nil { - c.goBackInternal() - return err - } - - // Server has been chosen for authentication - c.FSM.GoTransition(StateChosenServer) - - // Authorize it - if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil { - // Removing is best effort - _ = c.RemoveSecureInternet() - return err - } - c.FSM.GoTransition(StateNoServer) - return nil -} - -// AddCustomServer adds a Custom Server by `url`. -func (c *Client) AddCustomServer(url string) (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - if url, err = http.EnsureValidURL(url); err != nil { - return err - } - - // Indicate that we're loading the server - c.FSM.GoTransition(StateLoadingServer) - - customServer := &discotypes.Server{ - BaseURL: url, - DisplayName: map[string]string{"en": url}, - Type: "custom_server", - } - - // A custom server is just an institute access server under the hood - srv, err := c.Servers.AddCustomServer(customServer) - if err != nil { - c.goBackInternal() - return err - } - - // Set the server as the current so OAuth can be cancelled - if err = c.Servers.SetCustomServer(srv); err != nil { - c.goBackInternal() - return err - } - - // Server has been chosen for authentication - c.FSM.GoTransition(StateChosenServer) - - // Authorize it - if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil { - // removing is best effort - _ = c.RemoveCustomServer(url) - return err - } - - c.FSM.GoTransition(StateNoServer) - return nil -} - -// GetConfigInstituteAccess gets a configuration for an Institute Access Server. -// It ensures that the Institute Access Server exists by creating or using an existing one with the url. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return nil, errors.Errorf("discovery with Let's Connect is not supported") - } - - c.FSM.GoTransition(StateLoadingServer) - - // Get the server if it exists - var srv *server.InstituteAccessServer - if srv, err = c.Servers.GetInstituteAccess(url); err != nil { - c.goBackInternal() - return nil, err - } - - // Set the server as the current - if err = c.Servers.SetInstituteAccess(srv); err != nil { - return nil, err - } - - // The server has now been chosen - c.FSM.GoTransition(StateChosenServer) - - if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { - c.goBackInternal() - } - - // Also forward tokens using the callback - c.ForwardTokenUpdate(srv) - - return cfg, err -} - -// GetConfigSecureInternet gets a configuration for a Secure Internet Server. -// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -// TODO: Check on first argument orgID -func (c *Client) GetConfigSecureInternet(_ string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - log.Logger.Debugf("getting config for secure internet server with org ID: '%s", orgID) - - // Not supported with Let's Connect! - if c.isLetsConnect() { - return nil, errors.Errorf("discovery with Let's Connect is not supported") - } - - c.FSM.GoTransition(StateLoadingServer) - - // Get the server if it exists - var srv *server.SecureInternetHomeServer - if srv, err = c.Servers.GetSecureInternetHomeServer(); err != nil { - c.goBackInternal() - return nil, err - } - - // Set the server as the current - if err = c.Servers.SetSecureInternet(srv); err != nil { - return nil, err - } - - c.FSM.GoTransition(StateChosenServer) - - if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { - c.goBackInternal() - } - - // Also forward tokens using the callback - c.ForwardTokenUpdate(srv) - - return cfg, err -} - -// GetConfigCustomServer gets a configuration for a Custom Server. -// It ensures that the Custom Server exists by creating or using an existing one with the url. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - if url, err = http.EnsureValidURL(url); err != nil { - return nil, err - } - - c.FSM.GoTransition(StateLoadingServer) - - // Get the server if it exists - var srv *server.InstituteAccessServer - if srv, err = c.Servers.GetCustomServer(url); err != nil { - c.goBackInternal() - return nil, err - } - - // Set the server as the current - if err = c.Servers.SetCustomServer(srv); err != nil { - c.goBackInternal() - return nil, err - } - - c.FSM.GoTransition(StateChosenServer) - - if cfg, err = c.getConfig(srv, preferTCP, t); err != nil { - c.goBackInternal() - } - - // Also forward tokens using the callback - c.ForwardTokenUpdate(srv) - - return cfg, err -} - -// askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state. -func (c *Client) askSecureLocation() error { - loc := c.Discovery.SecureLocationList() - - c.locationWg.Add(1) - // Ask for the location in the callback - if err := c.FSM.GoTransitionRequired(StateAskLocation, loc); err != nil { - return err - } - - c.locationWg.Wait() - - // The state has changed, meaning setting the secure location was not successful - if c.FSM.Current != StateAskLocation { - log.Logger.Debugf("fsm failed to transit; expected %v / actual %v", GetStateName(StateAskLocation), GetStateName(c.FSM.Current)) - return errors.New("failed loading secure internet location") - } - return nil -} - -// RenewSession renews the session for the current VPN server. -// This logs the user back in. -func (c *Client) RenewSession() (err error) { - defer func() { - if err != nil { - c.logError(err) - } - }() - - var srv server.Server - if srv, err = c.Servers.GetCurrentServer(); err != nil { - return err - } - - err = srv.RefreshEndpoints(&c.Discovery) - if err != nil { - log.Logger.Warningf("failed to refresh server endpoints: %v", err) - } - - // The server has not been chosen yet, this means that we want to manually renew - if !c.FSM.InState(StateChosenServer) { - c.FSM.GoTransition(StateLoadingServer) - c.FSM.GoTransition(StateChosenServer) - } - - server.MarkTokensForRenew(srv) - return c.ensureLogin(srv, srvtypes.Tokens{}) -} - -// ensureLogin logs the user back in if needed. -// It runs the FSM transitions to ask for user input. -func (c *Client) ensureLogin(srv server.Server, t srvtypes.Tokens) (err error) { - // Relogin with oauth - // This moves the state to authorized - if !server.NeedsRelogin(srv) { - // OAuth was valid, ensure we are in the authorized state - c.FSM.GoTransition(StateAuthorized) - return nil - } - - // Try again but update the tokens using the client provided tokens - server.UpdateTokens(srv, oauth.Token{ - Access: t.Access, - Refresh: t.Refresh, - ExpiredTimestamp: time.Unix(t.Expires, 0), - }) - if !server.NeedsRelogin(srv) { - // OAuth was valid, ensure we are in the authorized state - c.FSM.GoTransition(StateAuthorized) - return nil - } - - // Mark organizations as expired if the server is a secure internet server - b, err := srv.Base() - // We only try to update it when we found the server base - if err == nil && b.Type == "secure_internet" { - c.Discovery.MarkOrganizationsExpired() - } - - // Tokens are not valid or the client gave an error when updating tokens - // Otherwise, do the OAuth exchange - var url string - if url, err = server.OAuthURL(srv, c.Name); err != nil { - return err - } - - if err = c.FSM.GoTransitionRequired(StateOAuthStarted, url); err != nil { - return err - } - - if err = server.OAuthExchange(srv); err != nil { - c.goBackInternal() - } - c.FSM.GoTransition(StateAuthorized) - - return err -} - -// SetProfileID sets a `profileID` for the current server. -// An error is returned if this is not possible, for example when no server is configured. -func (c *Client) SetProfileID(profileID string) (err error) { - if c.InFSMState(StateAskProfile) { - defer c.profileWg.Done() - } - defer func() { - if err != nil { - c.logError(err) - } - }() - - var srv server.Server - if srv, err = c.Servers.GetCurrentServer(); err != nil { - c.goBackInternal() - return err - } - - var b *server.Base - if b, err = srv.Base(); err != nil { - c.goBackInternal() - return err - } - b.Profiles.Current = profileID - - return nil -} - -func (c *Client) StartFailover(gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { - if c.Failover != nil { - return false, errors.New("another failover process is already started") - } - c.Failover = failover.New(readRxBytes) - - return c.Failover.Start(gateway, wgMTU) -} - -func (c *Client) CancelFailover() error { - if c.Failover == nil { - return errors.New("no failover process") - } - c.Failover.Cancel() - return nil -} |
