summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
Diffstat (limited to 'client')
-rw-r--r--client/client.go622
-rw-r--r--client/client_test.go84
-rw-r--r--client/fsm.go48
-rw-r--r--client/server.go701
4 files changed, 513 insertions, 942 deletions
diff --git a/client/client.go b/client/client.go
index 70adb71..813f6dc 100644
--- a/client/client.go
+++ b/client/client.go
@@ -2,9 +2,9 @@
package client
import (
+ "context"
"fmt"
"strings"
- "sync"
"github.com/eduvpn/eduvpn-common/internal/config"
"github.com/eduvpn/eduvpn-common/internal/discovery"
@@ -14,32 +14,12 @@ import (
"github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server"
+ "github.com/eduvpn/eduvpn-common/types/cookie"
discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
- "github.com/eduvpn/eduvpn-common/types/protocol"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/go-errors/errors"
)
-type (
- // ServerBase is an alias to the internal ServerBase
- // This contains the details for each server.
- ServerBase = server.Base
-)
-
-func (c *Client) logError(err error) {
- // Logs the error with the same level/verbosity as the error
- if c.Debug {
- log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack()))
- } else {
- log.Logger.Inherit(err, "")
- }
-}
-
-func (c *Client) isLetsConnect() bool {
- // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php
- return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app")
-}
-
// isAllowedClientID checks if the 'clientID' is in the list of allowed client IDs
func isAllowedClientID(clientID string) bool {
allowList := []string{
@@ -91,13 +71,27 @@ func userAgentName(clientID string) string {
}
}
+func (c *Client) logError(err error) {
+ // Logs the error with the same level/verbosity as the error
+ if c.Debug {
+ log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack()))
+ } else {
+ log.Logger.Inherit(err, "")
+ }
+}
+
+func (c *Client) isLetsConnect() bool {
+ // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php
+ return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app")
+}
+
// Client is the main struct for the VPN client.
type Client struct {
// The name of the client
Name string `json:"-"`
// The chosen server
- Servers server.Servers `json:"servers"`
+ Servers server.List `json:"servers"`
// The list of servers and organizations from disco
Discovery discovery.Discovery `json:"discovery"`
@@ -116,9 +110,6 @@ type Client struct {
// The Failover monitor for the current VPN connection
Failover *failover.DroppedConMon
-
- locationWg sync.WaitGroup
- profileWg sync.WaitGroup
}
// New creates a new client with the following parameters:
@@ -179,7 +170,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt
// Registering means updating the FSM to get to the initial state correctly
func (c *Client) Register() error {
- if !c.InFSMState(StateDeregistered) {
+ if !c.FSM.InState(StateDeregistered) {
return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current)
}
c.FSM.GoTransition(StateNoServer)
@@ -200,27 +191,11 @@ func (c *Client) Deregister() {
*c = Client{}
}
-// askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state.
-func (c *Client) askProfile(srv server.Server) error {
- ps, err := server.ValidProfiles(srv, c.SupportsWireguard)
- if err != nil {
- return err
- }
-
- c.profileWg.Add(1)
- if err = c.FSM.GoTransitionRequired(StateAskProfile, convertProfiles(*ps)); err != nil {
- return err
- }
- c.profileWg.Wait()
-
- return nil
-}
-
// DiscoOrganizations gets the organizations list from the discovery server
// If the list cannot be retrieved an error is returned.
// If this is the case then a previous version of the list is returned if there is any.
// This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list.
-func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error) {
+func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) {
defer func() {
if err != nil {
c.logError(err)
@@ -237,14 +212,15 @@ func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error
c.Discovery.MarkOrganizationsExpired()
}
- return c.Discovery.Organizations()
+ // TODO: pass a context
+ return c.Discovery.Organizations(ck.Context())
}
// DiscoServers gets the servers list from the discovery server
// If the list cannot be retrieved an error is returned.
// If this is the case then a previous version of the list is returned if there is any.
// This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list.
-func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) {
+func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) {
defer func() {
if err != nil {
c.logError(err)
@@ -256,7 +232,8 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) {
return nil, errors.Errorf("discovery with Let's Connect is not supported")
}
- return c.Discovery.Servers()
+ // TODO: pass a context
+ return c.Discovery.Servers(ck.Context())
}
// ExpiryTimes returns the different Unix timestamps regarding expiry
@@ -266,7 +243,7 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) {
// These times are reset when the VPN gets disconnected
func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) {
// Get current expiry time
- srv, err := c.Servers.GetCurrentServer()
+ srv, err := c.Servers.Current()
if err != nil {
c.logError(err)
return nil, err
@@ -293,149 +270,488 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) {
}, nil
}
-func convertProfiles(profiles server.ProfileInfo) srvtypes.Profiles {
- m := make(map[string]srvtypes.Profile)
- for _, p := range profiles.Info.ProfileList {
- var protocols []protocol.Protocol
- // loop through all protocol strings
- for _, ps := range p.VPNProtoList {
- protocols = append(protocols, protocol.New(ps))
- }
- m[p.ID] = srvtypes.Profile{
- DisplayName: map[string]string{
- "en": p.DisplayName,
- },
- Protocols: protocols,
+func (c *Client) locationCallback(ck *cookie.Cookie) error {
+ locs := c.Discovery.SecureLocationList()
+ errChan := make(chan error)
+ go func() {
+ err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{
+ C: ck,
+ Data: locs,
+ })
+ if err != nil {
+ errChan <- err
}
+ }()
+ loc, err := ck.Receive(errChan)
+ if err != nil {
+ return err
}
- return srvtypes.Profiles{Map: m, Current: profiles.Current}
+ err = c.SetSecureLocation(ck, loc)
+ if err != nil {
+ return err
+ }
+ t := c.FSM.GoTransition(StateChosenLocation)
+ if !t {
+ log.Logger.Warningf("transition chosen location not completed")
+ }
+ return nil
}
-func convertGeneric(server server.InstituteAccessServer) (*srvtypes.Server, error) {
- b, err := server.Base()
+func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error {
+ url, err := server.OAuthURL(srv, c.Name)
if err != nil {
- return nil, err
+ return err
}
- return &srvtypes.Server{
- DisplayName: b.DisplayName,
- Identifier: b.URL,
- Profiles: convertProfiles(b.Profiles),
- }, nil
+ err = c.FSM.GoTransitionRequired(StateOAuthStarted, url)
+ if err != nil {
+ return err
+ }
+ err = server.OAuthExchange(ck.Context(), srv)
+ if err != nil {
+ return err
+ }
+ return nil
}
-// TODO: CLEAN THIS UP
-func (c *Client) ServerList() (*srvtypes.List, error) {
- custom := c.Servers.CustomServers
- var customServers []srvtypes.Server
- for _, v := range custom.Map {
- if v == nil {
- return nil, errors.New("found nil value in custom server map")
+func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) error {
+ // location
+ if srv.NeedsLocation() {
+ err := c.locationCallback(ck)
+ if err != nil {
+ return err
}
- conv, err := convertGeneric(*v)
+ }
+
+ t := c.FSM.GoTransition(StateChosenServer)
+ if !t {
+ log.Logger.Warningf("transition not completed for chosen server")
+ }
+ // oauth
+ // TODO: This should be ck.Context()
+ // But needsrelogin needs a rewrite to support this properly
+ if server.NeedsRelogin(context.Background(), srv) || forceauth {
+ err := c.loginCallback(ck, srv)
if err != nil {
- return nil, errors.Errorf("failed to convert custom server for public type: %v", err)
+ return err
}
- customServers = append(customServers, *conv)
}
- institute := c.Servers.InstituteServers
- var instituteServers []srvtypes.Institute
- for _, v := range institute.Map {
- if v == nil {
- return nil, errors.New("found nil value in institute server map")
+ t = c.FSM.GoTransition(StateAuthorized)
+ if !t {
+ log.Logger.Warningf("transition authorized not completed")
+ }
+
+ return nil
+}
+
+func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error {
+ vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard)
+ if err != nil {
+ return err
+ }
+ if !vp {
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
- conv, err := convertGeneric(*v)
+ ps := b.Profiles.Public()
+ errChan := make(chan error)
+ go func() {
+ err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{
+ C: ck,
+ Data: ps,
+ })
+ if err != nil {
+ errChan <- err
+ }
+ }()
+ pID, err := ck.Receive(errChan)
if err != nil {
- return nil, errors.Errorf("failed to convert institute server for public type: %v", err)
+ return err
+ }
+ err = server.Profile(srv, pID)
+ if err != nil {
+ return err
}
- instituteServers = append(instituteServers, srvtypes.Institute{
- Server: *conv,
- // TODO: delisted
- Delisted: false,
- })
}
+ t := c.FSM.GoTransition(StateChosenProfile)
+ if !t {
+ log.Logger.Warningf("transition chosen profile not completed")
+ }
+ return nil
+}
- var secureInternet *srvtypes.SecureInternet
- if c.Servers.HasSecureInternet() {
- b, err := c.Servers.SecureInternetHomeServer.Base()
- if err == nil {
- generic := srvtypes.Server{
- DisplayName: b.DisplayName,
- Identifier: b.URL,
- Profiles: convertProfiles(b.Profiles),
- }
- cc := c.Servers.SecureInternetHomeServer.CurrentLocation
- secureInternet = &srvtypes.SecureInternet{
- Server: generic,
- CountryCode: cc,
- // TODO: delisted
- Delisted: false,
- }
+// AddServer adds a server with identifier and type
+func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) {
+ // If we have failed to add the server, we remove it again
+ // We add the server because we can then obtain it in other callback functions
+ defer func() {
+ if err != nil {
+ _ = c.RemoveServer(identifier, _type) //nolint:errcheck
}
+ c.FSM.GoTransition(StateNoServer)
+ }()
+ if !ni {
+ if !c.FSM.InState(StateNoServer) {
+ return errors.Errorf("wrong state to add a server: %s", GetStateName(c.FSM.Current))
+ }
+ t := c.FSM.GoTransition(StateLoadingServer)
+ if !t {
+ log.Logger.Warningf("transition not completed for loading server")
+ }
}
- return &srvtypes.List{
- Institutes: instituteServers,
- SecureInternet: secureInternet,
- Custom: customServers,
- }, nil
-}
-// TODO: CLEAN THIS UP
-func (c *Client) CurrentServer() (*srvtypes.Current, error) {
- srvs := c.Servers
+ identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet)
+ if err != nil {
+ return err
+ }
+
+ var srv server.Server
- switch srvs.IsType {
- case server.InstituteAccessServerType:
- curr, err := srvs.GetInstituteAccess(srvs.InstituteServers.CurrentURL)
+ switch _type {
+ case srvtypes.TypeInstituteAccess:
+ dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access")
if err != nil {
- return nil, err
+ return err
}
- conv, err := convertGeneric(*curr)
+ srv, err = c.Servers.AddInstituteAccess(ck.Context(), dSrv)
if err != nil {
- return nil, err
+ return err
}
- return &srvtypes.Current{
- Institute: &srvtypes.Institute{
- Server: *conv,
- // TODO: delisted
- Delisted: false,
- },
- Type: srvtypes.TypeInstituteAccess,
- }, nil
- case server.CustomServerType:
- curr, err := srvs.GetCustomServer(srvs.CustomServers.CurrentURL)
+ case srvtypes.TypeSecureInternet:
+ dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier)
if err != nil {
- return nil, err
+ return err
}
- conv, err := convertGeneric(*curr)
+ srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv)
if err != nil {
- return nil, err
+ return err
+ }
+ case srvtypes.TypeCustom:
+ srv, err = c.Servers.AddCustom(ck.Context(), identifier)
+ if err != nil {
+ return err
+ }
+ default:
+ return errors.Errorf("not a valid server type: %v", _type)
+ }
+
+ // if we are non interactive, we run no callbacks
+ if ni {
+ return nil
+ }
+
+ // callbacks
+ return c.callbacks(ck, srv, false)
+}
+
+func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) {
+ // do the callbacks to ensure valid profile, location and authorization
+ err = c.callbacks(ck, srv, forceAuth)
+ if err != nil {
+ return nil, err
+ }
+
+ t := c.FSM.GoTransition(StateRequestConfig)
+ if !t {
+ log.Logger.Warningf("transition not completed for requesting config")
+ }
+
+ err = c.profileCallback(ck, srv)
+ if err != nil {
+ return nil, err
+ }
+
+ cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP)
+ if err != nil {
+ return nil, err
+ }
+ p, err := server.CurrentProfile(srv)
+ if err != nil {
+ return nil, err
+ }
+ pcfg := cfgS.Public(p.DefaultGateway)
+ if err != nil {
+ return nil, err
+ }
+ return &pcfg, nil
+}
+
+func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) {
+ switch _type {
+ case srvtypes.TypeInstituteAccess:
+ srv, err = c.Servers.InstituteAccess(identifier)
+ setter = c.Servers.SetInstituteAccess
+ case srvtypes.TypeSecureInternet:
+ srv, err = c.Servers.SecureInternet(identifier)
+ setter = c.Servers.SetSecureInternet
+ case srvtypes.TypeCustom:
+ srv, err = c.Servers.CustomServer(identifier)
+ setter = c.Servers.SetCustom
+ default:
+ return nil, nil, errors.Errorf("not a valid server type: %v", _type)
+ }
+ return srv, setter, err
+}
+
+// GetConfig gets a VPN configuration
+func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) {
+ defer func() {
+ if err == nil {
+ c.FSM.GoTransition(StateGotConfig)
+ } else {
+ // go back if an error occurred
+ c.FSM.GoTransition(StateNoServer)
+ }
+ }()
+ identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet)
+ if err != nil {
+ return nil, err
+ }
+ t := c.FSM.GoTransition(StateLoadingServer)
+ if !t {
+ log.Logger.Warningf("transition not completed for loading server")
+ }
+ srv, set, err := c.server(identifier, _type)
+ if err != nil {
+ return nil, err
+ }
+ // refresh the server endpoints
+ err = server.RefreshEndpoints(ck.Context(), srv)
+ if err != nil {
+ log.Logger.Warningf("failed to refresh server endpoints: %v", err)
+ }
+
+ // get a config and retry with authorization if expired
+ cfg, err = c.config(ck, srv, pTCP, false)
+ tErr := &oauth.TokensInvalidError{}
+ if err != nil && errors.As(err, &tErr) {
+ cfg, err = c.config(ck, srv, pTCP, true)
+ }
+
+ // still an error, return nil with the error
+ if err != nil {
+ return nil, err
+ }
+
+ // set the current server
+ if err = set(srv); err != nil {
+ return nil, err
+ }
+
+ return cfg, nil
+}
+
+func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) {
+ identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet)
+ if err != nil {
+ return err
+ }
+ switch _type {
+ case srvtypes.TypeInstituteAccess:
+ return c.Servers.RemoveInstituteAccess(identifier)
+ case srvtypes.TypeSecureInternet:
+ return c.Servers.RemoveSecureInternet(identifier)
+ case srvtypes.TypeCustom:
+ return c.Servers.RemoveCustom(identifier)
+ default:
+ return errors.Errorf("not a valid server type: %v", _type)
+ }
+}
+
+func (c *Client) CurrentServer() (*srvtypes.Current, error) {
+ if !c.FSM.InState(StateGotConfig) {
+ return nil, errors.Errorf("State: %s, cannot have a current server. Did you get a VPN configuration?", GetStateName(c.FSM.Current))
+ }
+ srv, err := c.Servers.Current()
+ if err != nil {
+ return nil, err
+ }
+ return c.pubCurrentServer(srv)
+}
+
+func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
+ }
+ pub, err := srv.Public()
+ if err != nil {
+ return nil, err
+ }
+ switch t := pub.(type) {
+ case *srvtypes.Server:
+ if b.Type == srvtypes.TypeInstituteAccess {
+ return &srvtypes.Current{
+ Institute: &srvtypes.Institute{
+ Server: *t,
+ // TODO: delisted
+ Delisted: false,
+ },
+ Type: srvtypes.TypeInstituteAccess,
+ }, nil
}
return &srvtypes.Current{
- Custom: conv,
+ Custom: t,
Type: srvtypes.TypeCustom,
}, nil
- case server.SecureInternetServerType:
- b, err := c.Servers.SecureInternetHomeServer.Base()
+ case *srvtypes.SecureInternet:
+ t.Locations = c.Discovery.SecureLocationList()
+ return &srvtypes.Current{
+ SecureInternet: t,
+ Type: srvtypes.TypeSecureInternet,
+ }, nil
+ default:
+ panic("unknown type")
+ }
+}
+
+// TODO: This should not rely on interface{}
+func (c *Client) pubServer(srv server.Server) (interface{}, error) {
+ pub, err := srv.Public()
+ if err != nil {
+ return nil, err
+ }
+ switch t := pub.(type) {
+ case *srvtypes.Server:
+ b, err := srv.Base()
if err != nil {
return nil, err
}
- generic := srvtypes.Server{
- DisplayName: b.DisplayName,
- Identifier: c.Servers.SecureInternetHomeServer.HomeOrganizationID,
- Profiles: convertProfiles(b.Profiles),
- }
- cc := c.Servers.SecureInternetHomeServer.CurrentLocation
- return &srvtypes.Current{
- SecureInternet: &srvtypes.SecureInternet{
- Server: generic,
- CountryCode: cc,
+ if b.Type == srvtypes.TypeInstituteAccess {
+ return &srvtypes.Institute{
+ Server: *t,
// TODO: delisted
Delisted: false,
- },
- Type: srvtypes.TypeSecureInternet,
- }, nil
+ }, nil
+ }
+ return t, nil
+ case *srvtypes.SecureInternet:
+ t.Locations = c.Discovery.SecureLocationList()
+ return t, nil
default:
- return nil, errors.New("current server not found")
+ panic("unknown type")
+ }
+}
+
+func (c *Client) ServerList() (*srvtypes.List, error) {
+ if c.FSM.InState(StateDeregistered) {
+ return nil, errors.New("client is not registered")
+ }
+ var customServers []srvtypes.Server
+ for _, v := range c.Servers.CustomServers.Map {
+ if v == nil {
+ continue
+ }
+ p, err := c.pubServer(v)
+ if err != nil {
+ continue
+ }
+ c, ok := p.(*srvtypes.Server)
+ if !ok {
+ continue
+ }
+ customServers = append(customServers, *c)
}
+ var instituteServers []srvtypes.Institute
+ for _, v := range c.Servers.CustomServers.Map {
+ if v == nil {
+ continue
+ }
+ p, err := c.pubServer(v)
+ if err != nil {
+ continue
+ }
+ i, ok := p.(*srvtypes.Institute)
+ if !ok {
+ continue
+ }
+ instituteServers = append(instituteServers, *i)
+ }
+ var secureInternet *srvtypes.SecureInternet
+ if c.Servers.HasSecureInternet() {
+ srv := &c.Servers.SecureInternetHomeServer
+ p, err := c.pubServer(srv)
+ if err == nil {
+ s, ok := p.(*srvtypes.SecureInternet)
+ if ok {
+ secureInternet = s
+ }
+ }
+ }
+ return &srvtypes.List{
+ Institutes: instituteServers,
+ SecureInternet: secureInternet,
+ Custom: customServers,
+ }, nil
+}
+
+func (c *Client) SetProfileID(pID string) (err error) {
+ srv, err := c.Servers.Current()
+ if err != nil {
+ return err
+ }
+ return server.Profile(srv, pID)
+}
+
+func (c *Client) Cleanup(ck *cookie.Cookie) (err error) {
+ // get the current server
+ srv, err := c.Servers.Current()
+ if err != nil {
+ return err
+ }
+ // TODO: Support cookie context here
+ // if server.NeedsRelogin(context.Background(), srv) {
+ // // TODO: ask client for tokens
+ // }
+ err = server.Disconnect(ck.Context(), srv)
+ if err != nil {
+ return err
+ }
+ // TODO: Set tokens with callback
+ return nil
+}
+
+func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) {
+ if c.isLetsConnect() {
+ return errors.Errorf("setting a secure internet location with Let's Connect! is not supported")
+ }
+
+ if !c.Servers.HasSecureInternet() {
+ return errors.Errorf("no secure internet server available to set a location for")
+ }
+
+ dSrv, err := c.Discovery.ServerByCountryCode(countryCode)
+ if err != nil {
+ return err
+ }
+
+ return c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv)
+}
+
+func (c *Client) RenewSession(ck *cookie.Cookie) (err error) {
+ srv, err := c.Servers.Current()
+ if err != nil {
+ return err
+ }
+ // The server has not been chosen yet, this means that we want to manually renew
+ // TODO: is this needed?
+ if !c.FSM.InState(StateChosenServer) {
+ c.FSM.GoTransition(StateLoadingServer)
+ c.FSM.GoTransition(StateChosenServer)
+ }
+ // TODO: Maybe this can be deleted because we force auth now
+ server.MarkTokensForRenew(srv)
+ // run the callbacks by forcing auth
+ return c.callbacks(ck, srv, true)
+}
+
+func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) {
+ if c.Failover != nil {
+ return false, errors.New("another failover process is already started")
+ }
+
+ c.Failover = failover.New(readRxBytes)
+
+ return c.Failover.Start(ck.Context(), gateway, wgMTU)
}
diff --git a/client/client_test.go b/client/client_test.go
index 56c38ff..7077ce4 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -1,6 +1,7 @@
package client
import (
+ "context"
"fmt"
"net/http"
"net/url"
@@ -12,6 +13,7 @@ import (
"time"
httpw "github.com/eduvpn/eduvpn-common/internal/http"
+ "github.com/eduvpn/eduvpn-common/types/cookie"
"github.com/eduvpn/eduvpn-common/types/protocol"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/go-errors/errors"
@@ -22,7 +24,7 @@ func getServerURI(t *testing.T) string {
if serverURI == "" {
t.Skip("Skipping server test as no SERVER_URI env var has been passed")
}
- serverURI, parseErr := httpw.EnsureValidURL(serverURI)
+ serverURI, parseErr := httpw.EnsureValidURL(serverURI, true)
if parseErr != nil {
t.Skip("Skipping server test as the server uri is not valid")
}
@@ -41,13 +43,13 @@ func runCommand(errBuffer *strings.Builder, name string, args ...string) error {
return cmd.Wait()
}
-func loginOAuthSelenium(url string, state *Client) {
+func loginOAuthSelenium(ck *cookie.Cookie, url string) {
// We could use the go selenium library
// But it does not support the latest selenium v4 just yet
var errBuffer strings.Builder
err := runCommand(&errBuffer, "python3", "../selenium_eduvpn.py", url)
if err != nil {
- _ = state.CancelOAuth()
+ _ = ck.Cancel()
panic(fmt.Sprintf(
"Login OAuth with selenium script failed with error %v and stderr %s",
err,
@@ -58,10 +60,10 @@ func loginOAuthSelenium(url string, state *Client) {
func stateCallback(
t *testing.T,
+ ck *cookie.Cookie,
_ FSMStateID,
newState FSMStateID,
data interface{},
- state *Client,
) {
if newState == StateOAuthStarted {
url, ok := data.(string)
@@ -69,20 +71,20 @@ func stateCallback(
if !ok {
t.Fatalf("data is not a string for OAuth URL")
}
- loginOAuthSelenium(url, state)
+ loginOAuthSelenium(ck, url)
}
}
func TestServer(t *testing.T) {
serverURI := getServerURI(t)
- state := &Client{}
-
+ ck := cookie.NewWithContext(context.Background())
+ defer ck.Cancel() //nolint:errcheck
state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configstest",
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, old, new, data, state)
+ stateCallback(t, &ck, old, new, data)
return true
},
false,
@@ -95,12 +97,11 @@ func TestServer(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
-
- addErr := state.AddCustomServer(serverURI)
+ addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
- _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{})
+ _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
if configErr != nil {
t.Fatalf("Connect error: %v", configErr)
}
@@ -112,33 +113,36 @@ func testConnectOAuthParameter(
errPrefix string,
) {
serverURI := getServerURI(t)
- state := &Client{}
configDirectory := "test_oauth_parameters"
+ state := &Client{}
+
+ ck := cookie.NewWithContext(context.Background())
+ defer ck.Cancel() //nolint:errcheck
state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
configDirectory,
func(oldState FSMStateID, newState FSMStateID, data interface{}) bool {
if newState == StateOAuthStarted {
- server, serverErr := state.Servers.GetCustomServer(serverURI)
+ server, serverErr := state.Servers.CustomServer(serverURI)
if serverErr != nil {
t.Fatalf("No server with error: %v", serverErr)
}
port, portErr := server.OAuth().ListenerPort()
if portErr != nil {
- _ = state.CancelOAuth()
+ _ = ck.Cancel()
t.Fatalf("No port with error: %v", portErr)
}
baseURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port)
p, err := url.Parse(baseURL)
if err != nil {
- _ = state.CancelOAuth()
+ _ = ck.Cancel()
t.Fatalf("Failed to parse URL with error: %v", err)
}
url, err := httpw.ConstructURL(p, parameters)
if err != nil {
- _ = state.CancelOAuth()
+ _ = ck.Cancel()
t.Fatalf(
"Error: Constructing url %s with parameters %s",
baseURL,
@@ -148,7 +152,7 @@ func testConnectOAuthParameter(
go func() {
_, getErr := http.Get(url)
if getErr != nil {
- _ = state.CancelOAuth()
+ _ = ck.Cancel()
t.Logf("HTTP GET error: %v", getErr)
}
}()
@@ -165,7 +169,7 @@ func testConnectOAuthParameter(
t.Fatalf("Registering error: %v", err)
}
- err = state.AddCustomServer(serverURI)
+ err = state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
if errPrefix == "" {
if err != nil {
@@ -247,14 +251,14 @@ func TestTokenExpired(t *testing.T) {
}
// Get a vpn state
- state := &Client{}
-
+ ck := cookie.NewWithContext(context.Background())
+ defer ck.Cancel() //nolint:errcheck
state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configsexpired",
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, old, new, data, state)
+ stateCallback(t, &ck, old, new, data)
return true
},
false,
@@ -267,25 +271,25 @@ func TestTokenExpired(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddCustomServer(serverURI)
+ addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
- _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{})
+ _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
if configErr != nil {
t.Fatalf("Connect error before expired: %v", configErr)
}
- currentServer, serverErr := state.Servers.GetCurrentServer()
+ currentServer, serverErr := state.Servers.Current()
if serverErr != nil {
t.Fatalf("No server found")
}
serverOAuth := currentServer.OAuth()
- accessToken, accessTokenErr := serverOAuth.AccessToken()
+ accessToken, accessTokenErr := serverOAuth.AccessToken(ck.Context())
if accessTokenErr != nil {
t.Fatalf("Failed to get token: %v", accessTokenErr)
}
@@ -293,14 +297,14 @@ func TestTokenExpired(t *testing.T) {
// Wait for TTL so that the tokens expire
time.Sleep(time.Duration(expiredInt) * time.Second)
- _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{})
+ _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
if configErr != nil {
t.Fatalf("Connect error after expiry: %v", configErr)
}
// Check if tokens have changed
- accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken()
+ accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken(ck.Context())
if accessTokenAfterErr != nil {
t.Fatalf("Failed to get token: %v", accessTokenAfterErr)
}
@@ -313,14 +317,14 @@ func TestTokenExpired(t *testing.T) {
// Test if an invalid profile will be corrected.
func TestInvalidProfileCorrected(t *testing.T) {
serverURI := getServerURI(t)
- state := &Client{}
-
+ ck := cookie.NewWithContext(context.Background())
+ defer ck.Cancel() //nolint:errcheck
state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configscancelprofile",
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, old, new, data, state)
+ stateCallback(t, &ck, old, new, data)
return true
},
false,
@@ -333,18 +337,18 @@ func TestInvalidProfileCorrected(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddCustomServer(serverURI)
+ addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
- _, configErr := state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{})
+ _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
if configErr != nil {
t.Fatalf("First connect error: %v", configErr)
}
- currentServer, serverErr := state.Servers.GetCurrentServer()
+ currentServer, serverErr := state.Servers.Current()
if serverErr != nil {
t.Fatalf("No server found")
}
@@ -357,7 +361,7 @@ func TestInvalidProfileCorrected(t *testing.T) {
previousProfile := base.Profiles.Current
base.Profiles.Current = "IDONOTEXIST"
- _, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{})
+ _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
if configErr != nil {
t.Fatalf("Second connect error: %v", configErr)
}
@@ -374,14 +378,14 @@ func TestInvalidProfileCorrected(t *testing.T) {
// Test if prefer tcp is handled correctly by checking the returned config and config type.
func TestPreferTCP(t *testing.T) {
serverURI := getServerURI(t)
- state := &Client{}
-
+ ck := cookie.NewWithContext(context.Background())
+ defer ck.Cancel() //nolint:errcheck
state, err := New(
"org.letsconnect-vpn.app.linux",
"0.1.0-test",
"configsprefertcp",
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, old, new, data, state)
+ stateCallback(t, &ck, old, new, data)
return true
},
false,
@@ -394,13 +398,13 @@ func TestPreferTCP(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddCustomServer(serverURI)
+ addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
// get a config with preferTCP set to true
- config, configErr := state.GetConfigCustomServer(serverURI, true, srvtypes.Tokens{})
+ config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true)
// Test server should accept prefer TCP!
if config.Protocol != protocol.OpenVPN {
@@ -417,7 +421,7 @@ func TestPreferTCP(t *testing.T) {
}
// get a config with preferTCP set to false
- config, configErr = state.GetConfigCustomServer(serverURI, false, srvtypes.Tokens{})
+ config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false)
if configErr != nil {
t.Fatalf("Config error: %v", configErr)
}
diff --git a/client/fsm.go b/client/fsm.go
index 9f140e3..038e6cf 100644
--- a/client/fsm.go
+++ b/client/fsm.go
@@ -2,9 +2,6 @@ package client
import (
"github.com/eduvpn/eduvpn-common/internal/fsm"
- "github.com/eduvpn/eduvpn-common/internal/log"
- "github.com/eduvpn/eduvpn-common/internal/server"
- "github.com/go-errors/errors"
)
type (
@@ -94,7 +91,6 @@ func newFSM(
},
StateNoServer: FSMState{
Transitions: []FSMTransition{
- {To: StateNoServer, Description: "Reload list"},
{To: StateLoadingServer, Description: "User clicks a server in the UI"},
},
},
@@ -170,47 +166,3 @@ func newFSM(
returnedFSM.Init(StateDeregistered, states, callback, directory, GetStateName, debug)
return returnedFSM
}
-
-// GoBack transitions the FSM back to the previous UI state, for now this is always the NO_SERVER state.
-func (c *Client) GoBack() error {
- if c.InFSMState(StateDeregistered) {
- err := errors.Errorf("fsm attempt going back from 'StateDeregistered'")
- c.logError(err)
- return err
- }
-
- // FIXME: Arbitrary back transitions don't work because we need the appropriate data
- c.FSM.GoTransition(StateNoServer)
- return nil
-}
-
-// goBackInternal uses the public go back but logs an error if it happened.
-func (c *Client) goBackInternal() {
- err := c.GoBack()
- if err != nil {
- // TODO(jwijenbergh): Bit suspicious - logging level INFO, yet stacktrace logged.
- log.Logger.Infof("failed going back: %s\nstacktrace:\n%s", err.Error(), err.(*errors.Error).ErrorStack())
- }
-}
-
-// CancelOAuth cancels OAuth if one is in progress.
-// If OAuth is not in progress, it returns an error.
-// An error is also returned if OAuth is in progress, but it fails to cancel it.
-func (c *Client) CancelOAuth() error {
- if !c.InFSMState(StateOAuthStarted) {
- return errors.Errorf("fsm attempt cancelling OAuth while in '%v'", c.FSM.Current)
- }
-
- srv, err := c.Servers.GetCurrentServer()
- if err != nil {
- c.logError(err)
- } else {
- server.CancelOAuth(srv)
- }
- return err
-}
-
-// InFSMState is a helper to check if the FSM is in state `checkState`.
-func (c *Client) InFSMState(checkState FSMStateID) bool {
- return c.FSM.InState(checkState)
-}
diff --git a/client/server.go b/client/server.go
deleted file mode 100644
index b3f7747..0000000
--- a/client/server.go
+++ /dev/null
@@ -1,701 +0,0 @@
-package client
-
-import (
- "time"
-
- "github.com/eduvpn/eduvpn-common/internal/failover"
- "github.com/eduvpn/eduvpn-common/internal/http"
- "github.com/eduvpn/eduvpn-common/internal/log"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/internal/server"
- discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
- "github.com/eduvpn/eduvpn-common/types/protocol"
- srvtypes "github.com/eduvpn/eduvpn-common/types/server"
- "github.com/go-errors/errors"
-)
-
-// TODO: This should not be reliant on an internal type
-func getTokens(tok oauth.Token) srvtypes.Tokens {
- return srvtypes.Tokens{
- Access: tok.Access,
- Refresh: tok.Refresh,
- Expires: tok.ExpiredTimestamp.Unix(),
- }
-}
-
-// getConfigAuth gets a config with authorization and authentication.
-// It also asks for a profile if no valid profile is found.
-func (c *Client) getConfigAuth(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) {
- err := c.ensureLogin(srv, t)
- if err != nil {
- return nil, err
- }
-
- // TODO(jwijenbergh): Should we check if it returns false?
- c.FSM.GoTransition(StateRequestConfig)
-
- ok, err := server.HasValidProfile(srv, c.SupportsWireguard)
- if err != nil {
- return nil, err
- }
-
- // No valid profile, ask for one
- if !ok {
- if err = c.askProfile(srv); err != nil {
- return nil, err
- }
- }
-
- // The profile has been chosen
- c.FSM.GoTransition(StateChosenProfile)
-
- cfg, err := server.Config(srv, c.SupportsWireguard, preferTCP)
- if err != nil {
- return nil, err
- }
-
- p, err := server.CurrentProfile(srv)
- if err != nil {
- return nil, err
- }
-
- pCfg := &srvtypes.Configuration{
- VPNConfig: cfg.Config,
- Protocol: protocol.New(cfg.Type),
- DefaultGateway: p.DefaultGateway,
- Tokens: getTokens(cfg.Tokens),
- }
-
- return pCfg, nil
-}
-
-// retryConfigAuth retries the getConfigAuth function if the tokens are invalid.
-// If OAuth is cancelled, it makes sure that we only forward the error as additional info.
-func (c *Client) retryConfigAuth(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) {
- cfg, err := c.getConfigAuth(srv, preferTCP, t)
- if err == nil {
- return cfg, nil
- }
- // Only retry if the error is that the tokens are invalid
- tErr := &oauth.TokensInvalidError{}
- if errors.As(err, &tErr) {
- // TODO: Is passing empty tokens correct here?
- cfg, err = c.getConfigAuth(srv, preferTCP, srvtypes.Tokens{})
- if err == nil {
- return cfg, nil
- }
- }
- c.goBackInternal()
- return nil, err
-}
-
-// getConfig gets an OpenVPN/WireGuard configuration by contacting the server, moving the FSM towards the DISCONNECTED state and then saving the local configuration file.
-func (c *Client) getConfig(srv server.Server, preferTCP bool, t srvtypes.Tokens) (*srvtypes.Configuration, error) {
- if c.InFSMState(StateDeregistered) {
- return nil, errors.Errorf("getConfig attempt in '%v'", StateDeregistered)
- }
-
- // Refresh the server endpoints
- // This is the best effort
- err := srv.RefreshEndpoints(&c.Discovery)
- if err != nil {
- log.Logger.Warningf("failed to refresh server endpoints: %v", err)
- }
-
- cfg, err := c.retryConfigAuth(srv, preferTCP, t)
- if err != nil {
- return nil, err
- }
-
- // Save the config
- if err = c.Config.Save(&c); err != nil {
- // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace...
- // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well.
- log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s",
- err.Error(), err.(*errors.Error).ErrorStack())
- }
-
- c.FSM.GoTransition(StateGotConfig)
-
- return cfg, nil
-}
-
-// Cleanup cleans up the VPN connection by sending a /disconnect to the server
-func (c *Client) Cleanup(ct srvtypes.Tokens) error {
- srv, err := c.Servers.GetCurrentServer()
- if err != nil {
- c.logError(err)
- return err
- }
- err = srv.RefreshEndpoints(&c.Discovery)
- if err != nil {
- log.Logger.Warningf("failed to refresh server endpoints: %v", err)
- }
-
- // If we need to relogin, update tokens
- if server.NeedsRelogin(srv) {
- server.UpdateTokens(srv, oauth.Token{
- Access: ct.Access,
- Refresh: ct.Refresh,
- ExpiredTimestamp: time.Unix(ct.Expires, 0),
- })
- }
- // update tokens to client
- defer c.ForwardTokenUpdate(srv)
- // Do the /disconnect API call
- err = server.Disconnect(srv)
- if err != nil {
- // We log nothing here because this can happen regularly
- // Maybe we should not log errors that we return directly anyways?
- return err
- }
- // TODO: Tokens might be refreshed, return updated tokens
- // Not implemented yet, because ideally we want this implemented with an interface
- return nil
-}
-
-// SetSecureLocation sets the location for the current secure location server. countryCode is the secure location to be chosen.
-// This function returns an error e.g. if the server cannot be found or the location is wrong.
-func (c *Client) SetSecureLocation(countryCode string) error {
- if c.InFSMState(StateAskLocation) {
- defer c.locationWg.Done()
- }
- // Not supported with Let's Connect!
- if c.isLetsConnect() {
- err := errors.Errorf("discovery with Let's Connect is not supported")
- c.logError(err)
- return err
- }
-
- srv, err := c.Discovery.ServerByCountryCode(countryCode)
- if err != nil {
- c.goBackInternal()
- c.logError(err)
- return err
- }
-
- if err = c.Servers.SetSecureLocation(srv); err != nil {
- c.goBackInternal()
- c.logError(err)
- }
-
- return err
-}
-
-// RemoveSecureInternet removes the current secure internet server.
-// It returns an error if the server cannot be removed due to the state being DEREGISTERED.
-// Note that if the server does not exist, it returns nil as an error.
-func (c *Client) RemoveSecureInternet() error {
- if c.InFSMState(StateDeregistered) {
- err := errors.Errorf("RemoveSecureInternet attempt in '%v'", StateDeregistered)
- c.logError(err)
- return err
- }
- // No error because we can only have one secure internet server and if there are no secure internet servers, this is a NO-OP
- c.Servers.RemoveSecureInternet()
- c.FSM.GoTransition(StateNoServer)
- // Save the config
- if err := c.Config.Save(&c); err != nil {
- // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace...
- // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well.
- log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s",
- err.Error(), err.(*errors.Error).ErrorStack())
- }
- return nil
-}
-
-// RemoveInstituteAccess removes the institute access server with `url`.
-// It returns an error if the server cannot be removed due to the state being DEREGISTERED.
-// Note that if the server does not exist, it returns nil as an error.
-func (c *Client) RemoveInstituteAccess(url string) error {
- if c.InFSMState(StateDeregistered) {
- err := errors.Errorf("RemoveInstituteAccess attempt in '%v'", StateDeregistered)
- c.logError(err)
- return err
- }
- // No error because this is a NO-OP if the server doesn't exist
- c.Servers.RemoveInstituteAccess(url)
- c.FSM.GoTransition(StateNoServer)
- // Save the config
- if err := c.Config.Save(&c); err != nil {
- // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace...
- // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well.
- log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s",
- err.Error(), err.(*errors.Error).ErrorStack())
- }
- return nil
-}
-
-// RemoveCustomServer removes the custom server with `url`.
-// It returns an error if the server cannot be removed due to the state being DEREGISTERED.
-// Note that if the server does not exist, it returns nil as an error.
-func (c *Client) RemoveCustomServer(url string) error {
- if c.InFSMState(StateDeregistered) {
- err := errors.Errorf("RemoveCustomServer attempt in '%v'", StateDeregistered)
- c.logError(err)
- return err
- }
- // No error because this is a NO-OP if the server doesn't exist
- c.Servers.RemoveCustomServer(url)
- c.FSM.GoTransition(StateNoServer)
- // Save the config
- if err := c.Config.Save(&c); err != nil {
- // TODO(jwijenbergh): Not sure why INFO level, yet stacktrace...
- // TODO(jwijenbergh): Even worse, why logging it but then return nil? The calling code will think that everything went well.
- log.Logger.Infof("c.Config.Save failed: %s\nstacktrace:\n%s",
- err.Error(), err.(*errors.Error).ErrorStack())
- }
- return nil
-}
-
-// AddInstituteServer adds an Institute Access server by `url`.
-func (c *Client) AddInstituteServer(url string) (err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- // Not supported with Let's Connect!
- if c.isLetsConnect() {
- return errors.Errorf("adding and Institute Access server with Let's Connect is not supported")
- }
-
- // Indicate that we're loading the server
- c.FSM.GoTransition(StateLoadingServer)
-
- // Check if we are able to fetch discovery, and log if something went wrong
- if _, err := c.DiscoServers(); err != nil {
- log.Logger.Warningf("Failed to get discovery servers: %v", err)
- }
-
- if _, err := c.DiscoOrganizations(); err != nil {
- log.Logger.Warningf("Failed to get discovery organizations: %v", err)
- }
-
- // FIXME: Do nothing with discovery here as the client already has it
- // So pass a server as the parameter
- var dSrv *discotypes.Server
- dSrv, err = c.Discovery.ServerByURL(url, "institute_access")
- if err != nil {
- c.goBackInternal()
- return err
- }
-
- // Add the secure internet server
- srv, err := c.Servers.AddInstituteAccessServer(dSrv)
- if err != nil {
- c.goBackInternal()
- return err
- }
-
- // Set the server as the current so OAuth can be cancelled
- if err = c.Servers.SetInstituteAccess(srv); err != nil {
- c.goBackInternal()
- return err
- }
-
- // Indicate that we want to authorize this server
- c.FSM.GoTransition(StateChosenServer)
-
- // Authorize it
- if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil {
- // Removing is best effort
- _ = c.RemoveInstituteAccess(url)
- return err
- }
-
- c.FSM.GoTransition(StateNoServer)
- return nil
-}
-
-// AddSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file.
-// Because there is only one Secure Internet Home Server, it replaces the existing one.
-func (c *Client) AddSecureInternetHomeServer(orgID string) (err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- // Not supported with Let's Connect!
- if c.isLetsConnect() {
- return errors.Errorf("adding a secure internet server with Let's Connect is not supported")
- }
-
- // Indicate that we're loading the server
- c.FSM.GoTransition(StateLoadingServer)
-
- // Check if we are able to fetch discovery, and log if something went wrong
- if _, err := c.DiscoServers(); err != nil {
- log.Logger.Warningf("Failed to get discovery servers: %v", err)
- }
-
- if _, err := c.DiscoOrganizations(); err != nil {
- log.Logger.Warningf("Failed to get discovery organizations: %v", err)
- }
-
- // Get the secure internet URL from discovery
- org, dSrv, err := c.Discovery.SecureHomeArgs(orgID)
- if err != nil {
- // We mark the organizations as expired because we got an error
- // Note that in the docs it states that it only should happen when the Org ID doesn't exist
- // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found
- c.Discovery.MarkOrganizationsExpired()
- c.goBackInternal()
- return err
- }
-
- // Add the secure internet server
- srv, err := c.Servers.AddSecureInternet(org, dSrv)
- if err != nil {
- c.goBackInternal()
- return err
- }
-
- // TODO(jwijenbergh): Does this call transfers execution flow to UI?
- if err = c.askSecureLocation(); err != nil {
- // Removing is the best effort
- // This already goes back to the main screen
- _ = c.RemoveSecureInternet()
- return err
- }
-
- c.FSM.GoTransition(StateChosenLocation)
-
- // Set the server as the current so OAuth can be cancelled
- if err = c.Servers.SetSecureInternet(srv); err != nil {
- c.goBackInternal()
- return err
- }
-
- // Server has been chosen for authentication
- c.FSM.GoTransition(StateChosenServer)
-
- // Authorize it
- if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil {
- // Removing is best effort
- _ = c.RemoveSecureInternet()
- return err
- }
- c.FSM.GoTransition(StateNoServer)
- return nil
-}
-
-// AddCustomServer adds a Custom Server by `url`.
-func (c *Client) AddCustomServer(url string) (err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- if url, err = http.EnsureValidURL(url); err != nil {
- return err
- }
-
- // Indicate that we're loading the server
- c.FSM.GoTransition(StateLoadingServer)
-
- customServer := &discotypes.Server{
- BaseURL: url,
- DisplayName: map[string]string{"en": url},
- Type: "custom_server",
- }
-
- // A custom server is just an institute access server under the hood
- srv, err := c.Servers.AddCustomServer(customServer)
- if err != nil {
- c.goBackInternal()
- return err
- }
-
- // Set the server as the current so OAuth can be cancelled
- if err = c.Servers.SetCustomServer(srv); err != nil {
- c.goBackInternal()
- return err
- }
-
- // Server has been chosen for authentication
- c.FSM.GoTransition(StateChosenServer)
-
- // Authorize it
- if err = c.ensureLogin(srv, srvtypes.Tokens{}); err != nil {
- // removing is best effort
- _ = c.RemoveCustomServer(url)
- return err
- }
-
- c.FSM.GoTransition(StateNoServer)
- return nil
-}
-
-// GetConfigInstituteAccess gets a configuration for an Institute Access Server.
-// It ensures that the Institute Access Server exists by creating or using an existing one with the url.
-// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel.
-func (c *Client) GetConfigInstituteAccess(url string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- // Not supported with Let's Connect!
- if c.isLetsConnect() {
- return nil, errors.Errorf("discovery with Let's Connect is not supported")
- }
-
- c.FSM.GoTransition(StateLoadingServer)
-
- // Get the server if it exists
- var srv *server.InstituteAccessServer
- if srv, err = c.Servers.GetInstituteAccess(url); err != nil {
- c.goBackInternal()
- return nil, err
- }
-
- // Set the server as the current
- if err = c.Servers.SetInstituteAccess(srv); err != nil {
- return nil, err
- }
-
- // The server has now been chosen
- c.FSM.GoTransition(StateChosenServer)
-
- if cfg, err = c.getConfig(srv, preferTCP, t); err != nil {
- c.goBackInternal()
- }
-
- // Also forward tokens using the callback
- c.ForwardTokenUpdate(srv)
-
- return cfg, err
-}
-
-// GetConfigSecureInternet gets a configuration for a Secure Internet Server.
-// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID.
-// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel.
-// TODO: Check on first argument orgID
-func (c *Client) GetConfigSecureInternet(_ string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- log.Logger.Debugf("getting config for secure internet server with org ID: '%s", orgID)
-
- // Not supported with Let's Connect!
- if c.isLetsConnect() {
- return nil, errors.Errorf("discovery with Let's Connect is not supported")
- }
-
- c.FSM.GoTransition(StateLoadingServer)
-
- // Get the server if it exists
- var srv *server.SecureInternetHomeServer
- if srv, err = c.Servers.GetSecureInternetHomeServer(); err != nil {
- c.goBackInternal()
- return nil, err
- }
-
- // Set the server as the current
- if err = c.Servers.SetSecureInternet(srv); err != nil {
- return nil, err
- }
-
- c.FSM.GoTransition(StateChosenServer)
-
- if cfg, err = c.getConfig(srv, preferTCP, t); err != nil {
- c.goBackInternal()
- }
-
- // Also forward tokens using the callback
- c.ForwardTokenUpdate(srv)
-
- return cfg, err
-}
-
-// GetConfigCustomServer gets a configuration for a Custom Server.
-// It ensures that the Custom Server exists by creating or using an existing one with the url.
-// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel.
-func (c *Client) GetConfigCustomServer(url string, preferTCP bool, t srvtypes.Tokens) (cfg *srvtypes.Configuration, err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- if url, err = http.EnsureValidURL(url); err != nil {
- return nil, err
- }
-
- c.FSM.GoTransition(StateLoadingServer)
-
- // Get the server if it exists
- var srv *server.InstituteAccessServer
- if srv, err = c.Servers.GetCustomServer(url); err != nil {
- c.goBackInternal()
- return nil, err
- }
-
- // Set the server as the current
- if err = c.Servers.SetCustomServer(srv); err != nil {
- c.goBackInternal()
- return nil, err
- }
-
- c.FSM.GoTransition(StateChosenServer)
-
- if cfg, err = c.getConfig(srv, preferTCP, t); err != nil {
- c.goBackInternal()
- }
-
- // Also forward tokens using the callback
- c.ForwardTokenUpdate(srv)
-
- return cfg, err
-}
-
-// askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state.
-func (c *Client) askSecureLocation() error {
- loc := c.Discovery.SecureLocationList()
-
- c.locationWg.Add(1)
- // Ask for the location in the callback
- if err := c.FSM.GoTransitionRequired(StateAskLocation, loc); err != nil {
- return err
- }
-
- c.locationWg.Wait()
-
- // The state has changed, meaning setting the secure location was not successful
- if c.FSM.Current != StateAskLocation {
- log.Logger.Debugf("fsm failed to transit; expected %v / actual %v", GetStateName(StateAskLocation), GetStateName(c.FSM.Current))
- return errors.New("failed loading secure internet location")
- }
- return nil
-}
-
-// RenewSession renews the session for the current VPN server.
-// This logs the user back in.
-func (c *Client) RenewSession() (err error) {
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- var srv server.Server
- if srv, err = c.Servers.GetCurrentServer(); err != nil {
- return err
- }
-
- err = srv.RefreshEndpoints(&c.Discovery)
- if err != nil {
- log.Logger.Warningf("failed to refresh server endpoints: %v", err)
- }
-
- // The server has not been chosen yet, this means that we want to manually renew
- if !c.FSM.InState(StateChosenServer) {
- c.FSM.GoTransition(StateLoadingServer)
- c.FSM.GoTransition(StateChosenServer)
- }
-
- server.MarkTokensForRenew(srv)
- return c.ensureLogin(srv, srvtypes.Tokens{})
-}
-
-// ensureLogin logs the user back in if needed.
-// It runs the FSM transitions to ask for user input.
-func (c *Client) ensureLogin(srv server.Server, t srvtypes.Tokens) (err error) {
- // Relogin with oauth
- // This moves the state to authorized
- if !server.NeedsRelogin(srv) {
- // OAuth was valid, ensure we are in the authorized state
- c.FSM.GoTransition(StateAuthorized)
- return nil
- }
-
- // Try again but update the tokens using the client provided tokens
- server.UpdateTokens(srv, oauth.Token{
- Access: t.Access,
- Refresh: t.Refresh,
- ExpiredTimestamp: time.Unix(t.Expires, 0),
- })
- if !server.NeedsRelogin(srv) {
- // OAuth was valid, ensure we are in the authorized state
- c.FSM.GoTransition(StateAuthorized)
- return nil
- }
-
- // Mark organizations as expired if the server is a secure internet server
- b, err := srv.Base()
- // We only try to update it when we found the server base
- if err == nil && b.Type == "secure_internet" {
- c.Discovery.MarkOrganizationsExpired()
- }
-
- // Tokens are not valid or the client gave an error when updating tokens
- // Otherwise, do the OAuth exchange
- var url string
- if url, err = server.OAuthURL(srv, c.Name); err != nil {
- return err
- }
-
- if err = c.FSM.GoTransitionRequired(StateOAuthStarted, url); err != nil {
- return err
- }
-
- if err = server.OAuthExchange(srv); err != nil {
- c.goBackInternal()
- }
- c.FSM.GoTransition(StateAuthorized)
-
- return err
-}
-
-// SetProfileID sets a `profileID` for the current server.
-// An error is returned if this is not possible, for example when no server is configured.
-func (c *Client) SetProfileID(profileID string) (err error) {
- if c.InFSMState(StateAskProfile) {
- defer c.profileWg.Done()
- }
- defer func() {
- if err != nil {
- c.logError(err)
- }
- }()
-
- var srv server.Server
- if srv, err = c.Servers.GetCurrentServer(); err != nil {
- c.goBackInternal()
- return err
- }
-
- var b *server.Base
- if b, err = srv.Base(); err != nil {
- c.goBackInternal()
- return err
- }
- b.Profiles.Current = profileID
-
- return nil
-}
-
-func (c *Client) StartFailover(gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) {
- if c.Failover != nil {
- return false, errors.New("another failover process is already started")
- }
- c.Failover = failover.New(readRxBytes)
-
- return c.Failover.Start(gateway, wgMTU)
-}
-
-func (c *Client) CancelFailover() error {
- if c.Failover == nil {
- return errors.New("no failover process")
- }
- c.Failover.Cancel()
- return nil
-}