summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2024-02-07 13:48:30 +0100
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2024-02-19 14:15:07 +0100
commit16ac79c9de0dcdc2be7f5bf1c337c514ec2b757c (patch)
tree82693fa6f55168eb3957f3d683009d8195306714 /client
parent27a3ffe6d065ffe53c11f7629ba29987e39b8aeb (diff)
Client: Refactor to newest internal API
Diffstat (limited to 'client')
-rw-r--r--client/client.go864
-rw-r--r--client/client_test.go97
-rw-r--r--client/proxy.go28
-rw-r--r--client/token.go64
4 files changed, 377 insertions, 676 deletions
diff --git a/client/client.go b/client/client.go
index 4550ef0..5bcfb35 100644
--- a/client/client.go
+++ b/client/client.go
@@ -5,101 +5,94 @@ package client
import (
"context"
- "strings"
+ "errors"
"sync"
"time"
"github.com/eduvpn/eduvpn-common/i18nerr"
+ "github.com/eduvpn/eduvpn-common/internal/api"
"github.com/eduvpn/eduvpn-common/internal/config"
"github.com/eduvpn/eduvpn-common/internal/discovery"
"github.com/eduvpn/eduvpn-common/internal/failover"
"github.com/eduvpn/eduvpn-common/internal/fsm"
"github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/log"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server"
"github.com/eduvpn/eduvpn-common/types/cookie"
- discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
- "github.com/go-errors/errors"
+ "github.com/jwijenbergh/eduoauth-go"
)
// Client is the main struct for the VPN client.
type Client struct {
// The name of the client
- Name string `json:"-"`
+ Name string
- // The chosen server
- Servers server.List `json:"servers"`
-
- // The list of servers and organizations from disco
- Discovery discovery.Discovery `json:"discovery"`
+ // The servers
+ Servers server.Servers
// The fsm
- FSM fsm.FSM `json:"-"`
-
- // The config
- Config config.Config `json:"-"`
+ FSM fsm.FSM
// Whether or not this client supports WireGuard
- SupportsWireguard bool `json:"-"`
+ SupportsWireguard bool
// Whether to enable debugging
- Debug bool `json:"-"`
+ Debug bool
// TokenSetter sets the tokens in the client
- TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"`
+ TokenSetter func(sid string, stype srvtypes.Type, tok srvtypes.Tokens)
// TokenGetter gets the tokens from the client
- TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"`
+ TokenGetter func(sid string, stype srvtypes.Type) *srvtypes.Tokens
+
+ // tokenCacher
+ tokCacher TokenCacher
+
+ // cfg is the config
+ cfg *config.Config
mu sync.Mutex
}
-func (c *Client) updateTokens(srv server.Server) error {
- if c.TokenGetter == nil {
- return errors.New("no token getter defined")
- }
- pSrv, err := c.pubCurrentServer(srv)
- if err != nil {
- return err
- }
- // shouldn't happen
- if pSrv == nil {
- return errors.New("public server is nil when getting tokens")
- }
- tokens := c.TokenGetter(*pSrv)
- if tokens == nil {
- return errors.New("client returned nil for tokens")
+func (c *Client) GettingConfig() error {
+ if c.FSM.InState(StateGettingConfig) {
+ return nil
}
-
- server.UpdateTokens(srv, oauth.Token{
- Access: tokens.Access,
- Refresh: tokens.Refresh,
- ExpiredTimestamp: time.Unix(tokens.Expires, 0),
- })
-
- return nil
+ _, err := c.FSM.GoTransition(StateGettingConfig)
+ return err
}
-func (c *Client) forwardTokens(srv server.Server) error {
- if c.TokenSetter == nil {
- return errors.New("no token setter defined")
- }
- pSrv, err := c.pubCurrentServer(srv)
+func (c *Client) InvalidProfile(ctx context.Context, srv *server.Server) (string, error) {
+ // TODO: should this have profiles as a parameter
+ ck := cookie.NewWithContext(ctx)
+
+ prfs, err := srv.Profiles(ctx)
if err != nil {
- return err
+ return "", err
}
- if pSrv == nil {
- return errors.New("public server is nil when updating tokens")
+ if !c.SupportsWireguard {
+ prfs = prfs.FilterWireGuard()
}
- o := srv.OAuth()
- if o == nil {
- return errors.New("oauth was nil when forwarding tokens")
+ // we are guaranteed to have profiles > 0 (even after filtering)
+ // because internally this callback is only triggered if there is a choice to make
+
+ errChan := make(chan error)
+ go func() {
+ err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{
+ C: ck,
+ Data: prfs.Public(),
+ })
+ if err != nil {
+ errChan <- err
+ }
+ }()
+ pID, err := ck.Receive(errChan)
+ if err != nil {
+ return "", err
}
- t := o.Token()
- c.TokenSetter(*pSrv, t.Public())
- return nil
+
+ return pID, nil
}
func (c *Client) goTransition(id fsm.StateID) error {
@@ -157,91 +150,100 @@ func New(name string, version string, directory string, stateCallback func(FSMSt
// Debug only if given
c.Debug = debug
- // Initialize the Config
- c.Config.Init(directory, "state")
-
- // Try to load the previous configuration
- if c.Config.Load(&c) != nil {
- // This error can be safely ignored, as when the config does not load, the struct will not be filled
- log.Logger.Infof("Previous configuration not found")
- }
+ c.cfg = config.NewFromDirectory(directory)
+ // set the servers
+ c.Servers = server.NewServers(c.Name, c, c.SupportsWireguard, c.cfg.V2)
return c, nil
}
-// Registering means updating the FSM to get to the initial state correctly
-func (c *Client) Register() error {
- if !c.FSM.InState(StateDeregistered) {
- return i18nerr.NewInternal("The client tried to re-initialize without deregistering first")
+func (c *Client) TriggerAuth(ctx context.Context, url string, wait bool) (string, error) {
+ // Get a reply from the client
+ if wait {
+ ck := cookie.NewWithContext(ctx)
+ errChan := make(chan error)
+ go func() {
+ err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{
+ C: ck,
+ Data: url,
+ })
+ if err != nil {
+ errChan <- err
+ }
+ }()
+ g, err := ck.Receive(errChan)
+ if err != nil {
+ return "", err
+ }
+ return g, nil
}
- err := c.goTransition(StateNoServer)
+ // Otherwise do normal authorization (desktop clients)
+ err := c.FSM.GoTransitionRequired(StateOAuthStarted, url)
if err != nil {
- return err
+ return "", err
}
- return nil
+ return "", nil
}
-// SaveState saves the internal state to the config
-func (c *Client) SaveState() {
- log.Logger.Debugf("saving state configuration....")
- // Save the config
- if err := c.Config.Save(&c); err != nil {
- log.Logger.Infof("failed saving state configuration: '%v'", err)
+func (c *Client) AuthDone(id string, t srvtypes.Type) {
+ srv, err := c.Servers.GetServer(id, t)
+ if err == nil {
+ srv.LastAuthorizeTime = time.Now()
+ }
+ // TODO: Should this log anything if it fails?
+ // unhandled transition?
+ _, err = c.FSM.GoTransition(StateMain)
+ if err != nil {
+ log.Logger.Debugf("unhandled auth done main transition: %v", err)
}
}
-// Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct.
-func (c *Client) Deregister() {
- // First of all let's transition the state machine
- _ = c.goTransition(StateDeregistered)
-
- // SaveState saves the configuration
- c.SaveState()
-
- // Close the log file
- _ = log.Logger.Close()
-
- // Empty out the state
- *c = Client{}
-}
-
-// DiscoOrganizations gets the organizations list from the discovery server
-// If the list cannot be retrieved an error is returned.
-// If this is the case then a previous version of the list is returned if there is any.
-// This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list.
-func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) {
- // Not supported with Let's Connect! & govVPN
- if !c.hasDiscovery() {
- return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported")
+func (c *Client) TokensUpdated(id string, t srvtypes.Type, tok eduoauth.Token) {
+ if tok.Access == "" {
+ return
+ }
+ // Set the memory
+ err := c.tokCacher.Set(id, t, tok)
+ if err != nil {
+ log.Logger.Warningf("failed to set tokens into cache with error: %v", err)
}
- // Mark organizations as expired if we have not set an organization yet
- if !c.Servers.HasSecureInternet() {
- c.Discovery.MarkOrganizationsExpired()
+ if c.TokenSetter == nil {
+ return
}
+ // Update the client
+ c.TokenSetter(id, t, srvtypes.Tokens{
+ Access: tok.Access,
+ Refresh: tok.Refresh,
+ Expires: tok.ExpiredTimestamp.Unix(),
+ })
+}
- orgs, err = c.Discovery.Organizations(ck.Context())
+// Registering means updating the FSM to get to the initial state correctly
+func (c *Client) Register() error {
+ err := c.goTransition(StateMain)
if err != nil {
- err = i18nerr.Wrap(err, "An error occurred after getting the discovery files for the list of organizations")
+ return err
}
- return
+ return nil
}
-// DiscoServers gets the servers list from the discovery server
-// If the list cannot be retrieved an error is returned.
-// If this is the case then a previous version of the list is returned if there is any.
-// This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list.
-func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) {
- // Not supported with Let's Connect! & govVPN
- if !c.hasDiscovery() {
- return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported")
- }
+// Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct.
+func (c *Client) Deregister() {
+ // save the config
+ c.TrySave()
- dss, err = c.Discovery.Servers(ck.Context())
+ // Move the state machine back
+ _, err := c.FSM.GoTransition(StateDeregistered)
if err != nil {
- err = i18nerr.Wrap(err, "An error occurred after getting the discovery files for the list of servers")
+ log.Logger.Debugf("failed deregistered transition: %v", err)
}
- return
+
+ // Close the log file
+ _ = log.Logger.Close()
+
+ // Empty out the state
+ *c = Client{}
}
// ExpiryTimes returns the different Unix timestamps regarding expiry
@@ -250,34 +252,21 @@ func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err e
// - The list of times where notifications should be shown
// These times are reset when the VPN gets disconnected
func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) {
- // Get current expiry time
- srv, err := c.Servers.Current()
- if err != nil {
- return nil, i18nerr.Wrap(err, "The current server could not be found when getting it for expiry")
- }
- b, err := srv.Base()
+ srv, err := c.Servers.CurrentServer()
if err != nil {
- return nil, err
- }
-
- if b.StartTime.IsZero() {
- return nil, i18nerr.New("No start time is defined for this server")
+ return nil, i18nerr.Wrap(err, "The current server was not found when getting the VPN expiration date")
}
-
- bT := b.RenewButtonTime()
- cT := b.CountdownTime()
- nT := b.NotificationTimes()
return &srvtypes.Expiry{
- StartTime: b.StartTime.Unix(),
- EndTime: b.EndTime.Unix(),
- ButtonTime: bT,
- CountdownTime: cT,
- NotificationTimes: nT,
+ StartTime: srv.LastAuthorizeTime.Unix(),
+ EndTime: srv.ExpireTime.Unix(),
+ ButtonTime: server.RenewButtonTime(srv.LastAuthorizeTime, srv.ExpireTime),
+ CountdownTime: server.CountdownTime(srv.LastAuthorizeTime, srv.ExpireTime),
+ NotificationTimes: server.NotificationTimes(srv.LastAuthorizeTime, srv.ExpireTime),
}, nil
}
-func (c *Client) locationCallback(ck *cookie.Cookie) error {
- locs := c.Discovery.SecureLocationList()
+func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error {
+ locs := c.cfg.Discovery().SecureLocationList()
errChan := make(chan error)
go func() {
err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{
@@ -292,139 +281,19 @@ func (c *Client) locationCallback(ck *cookie.Cookie) error {
if err != nil {
return err
}
- err = c.SetSecureLocation(ck, loc)
- if err != nil {
- return err
- }
- err = c.goTransition(StateChosenLocation)
+ srv, err := c.Servers.GetServer(orgID, srvtypes.TypeSecureInternet)
if err != nil {
return err
}
+ srv.CountryCode = loc
return nil
}
-func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error {
- // get a custom redirect
- cr := CustomRedirect(c.Name)
- url, err := server.OAuthURL(srv, c.Name, cr)
- if err != nil {
- return err
- }
- authCodeURI := ""
- if cr != "" {
- errChan := make(chan error)
- go func() {
- err := c.FSM.GoTransitionRequired(StateOAuthStarted, &srvtypes.RequiredAskTransition{
- C: ck,
- Data: url,
- })
- if err != nil {
- errChan <- err
- }
- }()
- g, err := ck.Receive(errChan)
- if err != nil {
- return err
- }
- authCodeURI = g
- } else {
- err = c.FSM.GoTransitionRequired(StateOAuthStarted, url)
- if err != nil {
- return err
- }
- }
- err = server.OAuthExchange(ck.Context(), srv, authCodeURI)
- if err != nil {
- return err
- }
- return nil
-}
-
-func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool, startup bool) error {
- // location
- if srv.NeedsLocation() {
- if startup {
- return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no secure internet location is found. Please manually connect again", server.Name(srv))
- }
- err := c.locationCallback(ck)
- if err != nil {
- return i18nerr.Wrap(err, "The secure internet location could not be set")
- }
- }
-
- err := c.goTransition(StateChosenServer)
- if err != nil {
- log.Logger.Debugf("optional chosen server transition not possible: %v", err)
- }
- // oauth
- // TODO: This should be ck.Context()
- // But needsrelogin needs a rewrite to support this properly
-
- // first make sure we get the most up to date tokens from the client
- err = c.updateTokens(srv)
- if err != nil {
- log.Logger.Debugf("failed to get tokens from client: %v", err)
- }
- if server.NeedsRelogin(context.Background(), srv) || forceauth {
- if startup {
- return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but you need to authorizate again. Please manually connect again", server.Name(srv))
- }
- // mark organizations as expired if the server is a secure internet server
- b, berr := srv.Base()
- if berr == nil && b.Type == srvtypes.TypeSecureInternet {
- c.Discovery.MarkOrganizationsExpired()
- }
- err := c.loginCallback(ck, srv)
- if err != nil {
- return i18nerr.Wrap(err, "The authorization procedure failed to complete")
- }
- }
- err = c.goTransition(StateAuthorized)
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server, startup bool) error {
- vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard)
- if err != nil {
- log.Logger.Warningf("failed to determine whether the current protocol is valid with error: %v", err)
- return err
- }
- if !vp {
- if startup {
- return i18nerr.Newf("The client tried to autoconnect to the VPN server: %s, but no valid profiles were found. Please manually connect again", server.Name(srv))
- }
- vps, err := server.ValidProfiles(srv, c.SupportsWireguard)
- if err != nil {
- return i18nerr.Wrapf(err, "No suitable profiles could be found")
- }
- errChan := make(chan error)
- go func() {
- err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{
- C: ck,
- Data: vps.Public(),
- })
- if err != nil {
- errChan <- err
- }
- }()
- pID, err := ck.Receive(errChan)
- if err != nil {
- return i18nerr.Wrapf(err, "Profile with ID: '%s' could not be set", pID)
- }
- err = server.Profile(srv, pID)
- if err != nil {
- return i18nerr.Wrapf(err, "Profile with ID: '%s' could not be obtained from the server", pID)
- }
- }
- err = c.goTransition(StateChosenProfile)
+func (c *Client) TrySave() {
+ err := c.cfg.Save()
if err != nil {
- return err
+ log.Logger.Warningf("failed to save configuration: %v", err)
}
- return nil
}
// AddServer adds a server with identifier and type
@@ -435,463 +304,233 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
// We add the server because we can then obtain it in other callback functions
previousState := c.FSM.Current
defer func() {
- if err != nil {
- _ = c.RemoveServer(identifier, _type) //nolint:errcheck
- } else {
- c.SaveState()
- }
// If we must run callbacks, go to the previous state if we're not in it
if !ni && !c.FSM.InState(previousState) {
c.FSM.GoTransition(previousState) //nolint:errcheck
}
+ if err == nil {
+ c.TrySave()
+ }
}()
if !ni {
- err = c.goTransition(StateLoadingServer)
+ err = c.goTransition(StateAddingServer)
// this is already wrapped in an UI error
if err != nil {
return err
}
}
-
if _type != srvtypes.TypeSecureInternet {
+ // Convert to an identifier
identifier, err = http.EnsureValidURL(identifier, true)
if err != nil {
- return i18nerr.Wrap(err, "The identifier that was passed to the library is incorrect")
+ // TODO: wrap error
+ return err
}
}
- var srv server.Server
-
switch _type {
case srvtypes.TypeInstituteAccess:
- dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access")
- if err != nil {
- return i18nerr.Wrapf(err, "Could not retrieve institute access server with URL: '%s' from discovery", identifier)
- }
- srv, err = c.Servers.AddInstituteAccess(ck.Context(), c.Name ,dSrv)
+ _, err = c.Servers.AddInstitute(ck.Context(), c.cfg.Discovery(), identifier, ni)
if err != nil {
return i18nerr.Wrapf(err, "The institute access server with URL: '%s' could not be added", identifier)
}
case srvtypes.TypeSecureInternet:
- dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier)
- if err != nil {
- // We mark the organizations as expired because we got an error
- // Note that in the docs it states that it only should happen when the Org ID doesn't exist
- // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found
- c.Discovery.MarkOrganizationsExpired()
- return i18nerr.Wrapf(err, "The secure internet server with organisation ID: '%s' could not be retrieved from discovery", identifier)
- }
- srv, err = c.Servers.AddSecureInternet(ck.Context(), c.Name, dOrg, dSrv)
+ _, err = c.Servers.AddSecure(ck.Context(), c.cfg.Discovery(), identifier, ni)
if err != nil {
return i18nerr.Wrapf(err, "The secure internet server with organisation ID: '%s' could not be added", identifier)
}
case srvtypes.TypeCustom:
- srv, err = c.Servers.AddCustom(ck.Context(), c.Name, identifier)
+ _, err = c.Servers.AddCustom(ck.Context(), identifier, ni)
if err != nil {
return i18nerr.Wrapf(err, "The custom server with URL: '%s' could not be added", identifier)
}
default:
return i18nerr.NewInternalf("Server type: '%v' is not valid to be added", _type)
}
-
- // if we are non interactive, we run no callbacks
- if ni {
- return nil
- }
-
- // callbacks
- err = c.callbacks(ck, srv, false, false)
- // error is already UI wrapped
- if err != nil {
- return err
- }
- terr := c.forwardTokens(srv)
- if terr != nil {
- log.Logger.Debugf("failed to forward tokens after adding: %v", terr)
- }
return nil
}
-func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool, startup bool) (cfg *srvtypes.Configuration, err error) {
- // do the callbacks to ensure valid profile, location and authorization
- err = c.callbacks(ck, srv, forceAuth, startup)
- if err != nil {
- return nil, err
- }
-
- err = c.goTransition(StateRequestConfig)
- if err != nil {
- return nil, err
- }
-
- err = c.profileCallback(ck, srv, startup)
- if err != nil {
- return nil, err
- }
-
- cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP)
- if err != nil {
- return nil, i18nerr.Wrap(err, "The VPN configuration could not be obtained")
+func (c *Client) convertIdentifier(identifier string, t srvtypes.Type) (string, error) {
+ // assume secure internet identifiers are always valid as we can't really assume they are valid urls (+ always https)
+ if t == srvtypes.TypeSecureInternet {
+ return identifier, nil
}
- p, err := server.CurrentProfile(srv)
+ // Convert to an identifier, this also converts the scheme to HTTPS
+ identifier, err := http.EnsureValidURL(identifier, true)
if err != nil {
- return nil, i18nerr.Wrap(err, "The current profile could not be found")
+ return "", i18nerr.Wrapf(err, "input: '%s' is not a valid URL", identifier)
}
- pcfg := cfgS.Public(p.DefaultGateway)
- return &pcfg, nil
-}
-
-func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) {
- switch _type {
- case srvtypes.TypeInstituteAccess:
- srv, err = c.Servers.InstituteAccess(identifier)
- setter = c.Servers.SetInstituteAccess
- case srvtypes.TypeSecureInternet:
- srv, err = c.Servers.SecureInternet(identifier)
- setter = c.Servers.SetSecureInternet
- case srvtypes.TypeCustom:
- srv, err = c.Servers.CustomServer(identifier)
- setter = c.Servers.SetCustom
- default:
- return nil, nil, i18nerr.NewInternalf("Not a valid server type: %v", _type)
- }
- return srv, setter, err
+ return identifier, nil
}
// GetConfig gets a VPN configuration
-func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (cfg *srvtypes.Configuration, err error) {
+func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool, startup bool) (*srvtypes.Configuration, error) {
c.mu.Lock()
defer c.mu.Unlock()
previousState := c.FSM.Current
+ var err error
+
defer func() {
if err == nil {
c.FSM.GoTransition(StateGotConfig) //nolint:errcheck
- c.SaveState()
} else if !c.FSM.InState(previousState) {
// go back to the previous state if an error occurred
c.FSM.GoTransition(previousState) //nolint:errcheck
}
}()
- if _type != srvtypes.TypeSecureInternet {
- identifier, err = http.EnsureValidURL(identifier, true)
- if err != nil {
- return nil, i18nerr.Wrapf(err, "Identifier: '%s' for server with type: '%d' is not valid", identifier, _type)
- }
- }
- err = c.goTransition(StateLoadingServer)
- if err != nil {
- return nil, err
- }
- srv, set, err := c.server(identifier, _type)
- if err != nil {
- return nil, err
- }
- // refresh the server endpoints
- err = srv.RefreshEndpoints(ck.Context(), &c.Discovery)
- // If we get a canceled error, return that, otherwise just log the error
+ identifier, err = c.convertIdentifier(identifier, _type)
if err != nil {
- if errors.Is(err, context.Canceled) {
- return nil, i18nerr.Wrap(err, "The operation for getting a VPN configuration was canceled")
- }
-
- log.Logger.Warningf("failed to refresh server endpoints: %v", err)
+ return nil, i18nerr.Wrapf(err, "Server identifier: '%s', is not valid when getting a VPN configuration", identifier)
}
-
- // get a config and retry with authorization if expired
- cfg, err = c.config(ck, srv, pTCP, false, startup)
- tErr := &oauth.TokensInvalidError{}
- if err != nil && errors.As(err, &tErr) {
- log.Logger.Debugf("the tokens were invalid, trying again...")
- cfg, err = c.config(ck, srv, pTCP, true, startup)
- }
-
- // tokens might be updated, forward them
- defer func() {
- terr := c.forwardTokens(srv)
- if terr != nil {
- log.Logger.Debugf("failed to forward tokens after get config: %v", terr)
- }
- }()
-
- // still an error, return nil with the error
+ err = c.GettingConfig()
if err != nil {
- return nil, err
- }
-
- // set the current server
- if err = set(srv); err != nil {
- return nil, i18nerr.Wrapf(err, "Failed to set the server with identifier: '%s' as the current", identifier)
+ log.Logger.Debugf("failed getting config transition: %v", err)
}
- return cfg, nil
-}
-
-func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) {
- if _type != srvtypes.TypeSecureInternet {
- identifier, err = http.EnsureValidURL(identifier, true)
- if err != nil {
- return i18nerr.Wrapf(err, "Identifier: '%s' for server with type: '%d' is not valid for removal", identifier, _type)
- }
+ tok, err := c.retrieveTokens(identifier, _type)
+ if err != nil {
+ log.Logger.Debugf("no tokens found for server: '%s', with error: '%v'", identifier, err)
}
- // miscellaneous error
- var mErr error
+ var srv *server.Server
switch _type {
case srvtypes.TypeInstituteAccess:
- mErr = c.Servers.RemoveInstituteAccess(identifier)
+ srv, err = c.Servers.GetInstitute(ck.Context(), identifier, c.cfg.Discovery(), tok, startup)
case srvtypes.TypeSecureInternet:
- mErr = c.Servers.RemoveSecureInternet(identifier)
+ srv, err = c.Servers.GetSecure(ck.Context(), identifier, c.cfg.Discovery(), tok, startup)
+
+ var cErr *discovery.CountryNotFoundError
+ if errors.As(err, &cErr) {
+ err = c.locationCallback(ck, identifier)
+ if err == nil {
+ srv, err = c.Servers.GetSecure(ck.Context(), identifier, c.cfg.Discovery(), tok, startup)
+ }
+ }
case srvtypes.TypeCustom:
- mErr = c.Servers.RemoveCustom(identifier)
+ srv, err = c.Servers.GetCustom(ck.Context(), identifier, tok, startup)
default:
- return i18nerr.NewInternalf("Not a valid server type: %v", _type)
+ err = i18nerr.NewInternalf("Server type: '%v' is not valid to get a config for", _type)
}
- if mErr != nil {
- log.Logger.Debugf("failed to remove server with identifier: '%s' and type: '%d', error: %v", identifier, _type, mErr)
+ if err != nil {
+ if startup {
+ if errors.Is(err, api.ErrAuthorizeDisabled) {
+ return nil, i18nerr.Newf("The client tried to autoconnect to the VPN server: '%s', but you need to authorizate again. Please manually connect again", identifier)
+ }
+ return nil, i18nerr.Wrapf(err, "The client tried to autoconnect to the VPN server: '%s', but the operation failed to complete", identifier)
+ }
+ return nil, i18nerr.Wrapf(err, "Server: '%s' could not be obtained", identifier)
}
- c.SaveState()
- return nil
-}
-func (c *Client) CurrentServer() (*srvtypes.Current, error) {
- srv, err := c.Servers.Current()
+ cfg, err := c.Servers.ConnectWithCallbacks(ck.Context(), srv, pTCP)
if err != nil {
- return nil, err
+ return nil, i18nerr.Wrapf(err, "No VPN configuration for server: '%s' could be obtained", identifier)
}
- return c.pubCurrentServer(srv)
+ return cfg, nil
}
-func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) {
- b, err := srv.Base()
+func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) {
+ identifier, err = c.convertIdentifier(identifier, _type)
if err != nil {
- return nil, err
+ return i18nerr.Wrapf(err, "Server identifier: '%s', is not valid when removing the server", identifier)
}
- pub, err := srv.Public()
+ err = c.Servers.Remove(identifier, _type)
if err != nil {
- return nil, err
- }
- switch t := pub.(type) {
- case *srvtypes.Server:
- if b.Type == srvtypes.TypeInstituteAccess {
- return &srvtypes.Current{
- Institute: &srvtypes.Institute{
- Server: *t,
- SupportContacts: b.SupportContact,
- // TODO: delisted
- Delisted: false,
- },
- Type: srvtypes.TypeInstituteAccess,
- }, nil
- }
- return &srvtypes.Current{
- Custom: t,
- Type: srvtypes.TypeCustom,
- }, nil
- case *srvtypes.SecureInternet:
- t.SupportContacts = b.SupportContact
- t.Locations = c.Discovery.SecureLocationList()
- return &srvtypes.Current{
- SecureInternet: t,
- Type: srvtypes.TypeSecureInternet,
- }, nil
- default:
- panic("unknown type")
+ return i18nerr.Wrapf(err, "The server: '%s' could not be removed", identifier)
}
+ return nil
}
-// TODO: This should not rely on interface{}
-func (c *Client) pubServer(srv server.Server) (interface{}, error) {
- pub, err := srv.Public()
+func (c *Client) CurrentServer() (*srvtypes.Current, error) {
+ curr, err := c.Servers.PublicCurrent(c.cfg.Discovery())
if err != nil {
- return nil, err
+ return nil, i18nerr.Wrap(err, "The current server could not be retrieved")
}
- b, err := srv.Base()
+ return curr, nil
+}
+
+func (c *Client) SetProfileID(pID string) error {
+ srv, err := c.Servers.CurrentServer()
if err != nil {
- return nil, err
- }
- switch t := pub.(type) {
- case *srvtypes.Server:
- if b.Type == srvtypes.TypeInstituteAccess {
- return &srvtypes.Institute{
- Server: *t,
- SupportContacts: b.SupportContact,
- // TODO: delisted
- Delisted: false,
- }, nil
- }
- return t, nil
- case *srvtypes.SecureInternet:
- t.SupportContacts = b.SupportContact
- t.Locations = c.Discovery.SecureLocationList()
- return t, nil
- default:
- panic("unknown type")
+ return i18nerr.Wrapf(err, "Failed to set the profile ID: '%s'", pID)
}
+ srv.Profiles.Current = pID
+ return nil
}
-func (c *Client) ServerList() (*srvtypes.List, error) {
- if c.FSM.InState(StateDeregistered) {
- return nil, i18nerr.NewInternal("Client is not registered")
- }
- var customServers []srvtypes.Server
- for _, v := range c.Servers.CustomServers.Map {
- if v == nil {
- continue
- }
- p, err := c.pubServer(v)
- if err != nil {
- continue
- }
- c, ok := p.(*srvtypes.Server)
- if !ok {
- continue
- }
- customServers = append(customServers, *c)
+func (c *Client) retrieveTokens(sid string, t srvtypes.Type) (*eduoauth.Token, error) {
+ // get from memory
+ tok, err := c.tokCacher.Get(sid, t)
+ if err == nil {
+ return tok, nil
}
- var instituteServers []srvtypes.Institute
- for _, v := range c.Servers.InstituteServers.Map {
- if v == nil {
- continue
- }
- p, err := c.pubServer(v)
- if err != nil {
- continue
- }
- i, ok := p.(*srvtypes.Institute)
- if !ok {
- continue
- }
- instituteServers = append(instituteServers, *i)
+ if c.TokenGetter == nil {
+ return tok, err
}
- var secureInternet *srvtypes.SecureInternet
- if c.Servers.HasSecureInternet() {
- srv := &c.Servers.SecureInternetHomeServer
- p, err := c.pubServer(srv)
- if err == nil {
- s, ok := p.(*srvtypes.SecureInternet)
- if ok {
- secureInternet = s
- }
- }
+ // get from client
+ gtok := c.TokenGetter(sid, t)
+ if gtok == nil {
+ return nil, errors.New("client returned nil tokens")
}
- return &srvtypes.List{
- Institutes: instituteServers,
- SecureInternet: secureInternet,
- Custom: customServers,
+ return &eduoauth.Token{
+ Access: gtok.Access,
+ Refresh: gtok.Refresh,
+ ExpiredTimestamp: time.Unix(gtok.Expires, 0),
}, nil
}
-func (c *Client) SetProfileID(pID string) (err error) {
- srv, err := c.Servers.Current()
+func (c *Client) Cleanup(ck *cookie.Cookie) error {
+ srv, err := c.Servers.CurrentServer()
if err != nil {
- return err
- }
- err = server.Profile(srv, pID)
- if err == nil {
- c.SaveState()
+ return i18nerr.Wrap(err, "The current server was not found when cleaning up the connection")
}
- return err
-}
-
-func (c *Client) Cleanup(ck *cookie.Cookie) (err error) {
- // get the current server
- srv, err := c.Servers.Current()
+ tok, err := c.retrieveTokens(srv.T.ID, srv.T.T)
if err != nil {
- return i18nerr.Wrap(err, "Failed to get the current server to cleanup the connection")
+ return i18nerr.Wrap(err, "No OAuth tokens were found when cleaning up the connection")
}
-
- err = srv.RefreshEndpoints(ck.Context(), &c.Discovery)
-
- // If we get a canceled error, return that, otherwise just log the error
+ auth, err := srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), tok, true)
if err != nil {
- if errors.Is(err, context.Canceled) {
- return i18nerr.Wrap(err, "The cleanup process was canceled")
- }
-
- log.Logger.Warningf("failed to refresh server endpoints: %v", err)
- }
-
-
- defer c.SaveState()
- err = c.updateTokens(srv)
- if err != nil {
- log.Logger.Debugf("failed to update tokens for disconnect: %v", err)
- }
- err = server.Disconnect(ck.Context(), srv)
- if err != nil {
- return i18nerr.Wrap(err, "Failed to cleanup the VPN connection for the current server")
+ return i18nerr.Wrap(err, "The server was unable to be retrieved when cleaning up the connection")
}
- err = c.forwardTokens(srv)
+ err = auth.Disconnect(ck.Context())
if err != nil {
- log.Logger.Debugf("failed to forward tokens after disconnect: %v", err)
+ return i18nerr.Wrap(err, "Failed to cleanup the VPN connection")
}
return nil
}
-func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) {
+func (c *Client) SetSecureLocation(orgID string, countryCode string) error {
// not supported with Let's Connect! & govVPN
if !c.hasDiscovery() {
return i18nerr.NewInternal("Setting a secure internet location with this client ID is not supported")
}
-
- if !c.Servers.HasSecureInternet() {
- return i18nerr.Newf("No secure internet server available to set a location for")
- }
-
- dSrv, err := c.Discovery.ServerByCountryCode(countryCode)
+ srv, err := c.Servers.GetServer(orgID, srvtypes.TypeSecureInternet)
if err != nil {
- return err
- }
-
- err = c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv)
- if err == nil {
- c.SaveState()
+ return i18nerr.Wrapf(err, "Failed to get the secure internet server with id: '%s' for setting a location", orgID)
}
- return err
+ srv.CountryCode = countryCode
+ return nil
}
-func (c *Client) RenewSession(ck *cookie.Cookie) (err error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- srv, err := c.Servers.Current()
+func (c *Client) RenewSession(ck *cookie.Cookie) error {
+ // getting the current serving with nil tokens means re-authorize
+ srv, err := c.Servers.CurrentServer()
if err != nil {
- return i18nerr.Wrap(err, "Failed to get current server for renewing the session")
- }
- // The server has not been chosen yet, this means that we want to manually renew
- // TODO: is this needed?
- if !c.FSM.InState(StateLoadingServer) {
- c.FSM.GoTransition(StateLoadingServer) //nolint:errcheck
+ return i18nerr.Wrap(err, "The current server could not be retrieved when renewing the session")
}
- err = srv.RefreshEndpoints(ck.Context(), &c.Discovery)
- // If we get a canceled error, return that, otherwise just log the error
+ // getting a server with no tokens means re-authorize
+ _, err = srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), nil, false)
if err != nil {
- if errors.Is(err, context.Canceled) {
- return i18nerr.Wrap(err, "The renewing process was canceled")
- }
-
- log.Logger.Warningf("failed to refresh server endpoints: %v", err)
+ return i18nerr.Wrap(err, "The server was unable to be retrieved when renewing the session")
}
-
-
- // update tokens in the end
- defer func() {
- terr := c.forwardTokens(srv)
- if terr != nil {
- log.Logger.Debugf("failed to forward tokens after renew: %v", terr)
- }
- }()
- defer c.SaveState()
- // TODO: Maybe this can be deleted because we force auth now
- server.MarkTokensForRenew(srv)
- // run the callbacks by forcing auth
- return c.callbacks(ck, srv, true, false)
+ return nil
}
func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readRxBytes func() (int64, error)) (bool, error) {
f := failover.New(readRxBytes)
+ // get current profile
d, err := f.Start(ck.Context(), gateway, mtu)
if err != nil {
return d, i18nerr.Wrapf(err, "Failover failed to complete with gateway: '%s' and MTU: '%d'", gateway, mtu)
@@ -899,24 +538,7 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR
return d, nil
}
-func (c *Client) SetState(state FSMStateID) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- curr := c.FSM.Current
- _, err := c.FSM.GoTransition(state)
- if err != nil {
- // self-transitions are only debug errors
- if c.FSM.InState(state) {
- log.Logger.Debugf("attempt an invalid self-transition: %s", c.FSM.GetStateName(state))
- return nil
- }
- return i18nerr.WrapInternalf(err, "Failed internal state transition requested by the client from: '%s' to '%s'", GetStateName(curr), GetStateName(state))
- }
- return nil
-}
-
-func (c *Client) InState(state FSMStateID) bool {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.FSM.InState(state)
+func (c *Client) ServerList() (*srvtypes.List, error) {
+ g := c.cfg.V2.PublicList(c.cfg.Discovery())
+ return g, nil
}
diff --git a/client/client_test.go b/client/client_test.go
index 7d070f3..1d4bf44 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -14,6 +14,7 @@ import (
"github.com/eduvpn/eduvpn-common/types/cookie"
"github.com/eduvpn/eduvpn-common/types/protocol"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/jwijenbergh/eduoauth-go"
)
func getServerURI(t *testing.T) string {
@@ -56,7 +57,6 @@ func loginOAuthSelenium(ck *cookie.Cookie, url string) {
}
func stateCallback(
- t *testing.T,
ck *cookie.Cookie,
_ FSMStateID,
newState FSMStateID,
@@ -66,7 +66,7 @@ func stateCallback(
url, ok := data.(string)
if !ok {
- t.Fatalf("data is not a string for OAuth URL")
+ panic("data is not a string for OAuth URL")
}
loginOAuthSelenium(ck, url)
}
@@ -82,7 +82,7 @@ func TestServer(t *testing.T) {
"0.1.0-test",
dir,
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, &ck, old, new, data)
+ go stateCallback(ck, old, new, data)
return true
},
false,
@@ -95,11 +95,11 @@ func TestServer(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
+ addErr := state.AddServer(ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
- _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
+ _, configErr := state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Connect error: %v", configErr)
}
@@ -130,7 +130,7 @@ func TestTokenExpired(t *testing.T) {
"0.1.0-test",
dir,
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, &ck, old, new, data)
+ go stateCallback(ck, old, new, data)
return true
},
false,
@@ -143,45 +143,37 @@ func TestTokenExpired(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
+ addErr := state.AddServer(ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
- _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
+ _, configErr := state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Connect error before expired: %v", configErr)
}
- currentServer, serverErr := state.Servers.Current()
- if serverErr != nil {
- t.Fatalf("No server found")
- }
-
- serverOAuth := currentServer.OAuth()
-
- accessToken, accessTokenErr := serverOAuth.AccessToken(ck.Context())
- if accessTokenErr != nil {
- t.Fatalf("Failed to get token: %v", accessTokenErr)
+ // get token before
+ tb, err := state.retrieveTokens(serverURI, srvtypes.TypeCustom)
+ if err != nil {
+ t.Fatalf("No tokens found: %v", err)
}
// Wait for TTL so that the tokens expire
time.Sleep(time.Duration(expiredInt) * time.Second)
- _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
+ _, configErr = state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Connect error after expiry: %v", configErr)
}
- // Check if tokens have changed
- accessTokenAfter, accessTokenAfterErr := serverOAuth.AccessToken(ck.Context())
- if accessTokenAfterErr != nil {
- t.Fatalf("Failed to get token: %v", accessTokenAfterErr)
+ ta, err := state.retrieveTokens(serverURI, srvtypes.TypeCustom)
+ if err != nil {
+ t.Fatalf("No tokens found after: %v", err)
}
-
- if accessToken == accessTokenAfter {
+ if tb.Access == ta.Access {
t.Errorf("Access token is the same after refresh")
}
}
@@ -197,7 +189,7 @@ func TestInvalidProfileCorrected(t *testing.T) {
"0.1.0-test",
dir,
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, &ck, old, new, data)
+ go stateCallback(ck, old, new, data)
return true
},
false,
@@ -210,40 +202,35 @@ func TestInvalidProfileCorrected(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
+ addErr := state.AddServer(ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
- _, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
+ _, configErr := state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("First connect error: %v", configErr)
}
- currentServer, serverErr := state.Servers.Current()
- if serverErr != nil {
- t.Fatalf("No server found")
- }
-
- base, baseErr := currentServer.Base()
- if baseErr != nil {
- t.Fatalf("No base found")
+ s, err := state.Servers.CurrentServer()
+ if err != nil {
+ t.Fatalf("Got error when getting current server: %v", err)
}
+ // set invalid profile
+ invp := "IDONOTEXIST"
+ s.Profiles.Current = invp
- previousProfile := base.Profiles.Current
- base.Profiles.Current = "IDONOTEXIST"
-
- _, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
+ _, configErr = state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Second connect error: %v", configErr)
}
- if base.Profiles.Current != previousProfile {
+ if s.Profiles.Current == invp {
t.Fatalf(
"Profiles do no match: current %s and previous %s",
- base.Profiles.Current,
- previousProfile,
+ s.Profiles.Current,
+ invp,
)
}
}
@@ -259,7 +246,7 @@ func TestConfigStartup(t *testing.T) {
"0.1.0-test",
dir,
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, &ck, old, new, data)
+ go stateCallback(ck, old, new, data)
return true
},
false,
@@ -272,14 +259,14 @@ func TestConfigStartup(t *testing.T) {
t.Fatalf("Failed to register with error: %v", err)
}
// we set true as last argument here such that no callbacks are ran
- err = state.AddServer(&ck, serverURI, srvtypes.TypeCustom, true)
+ err = state.AddServer(ck, serverURI, srvtypes.TypeCustom, true)
if err != nil {
t.Fatalf("Failed to add server for trying config startup: %v", err)
}
testTrue := func() {
// Now get config with setting startup to true
startup := true
- _, err := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, startup)
+ _, err := state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, startup)
// this should fail as we have not authorized yet/chosen profile and startup=true does not do these callbacks
if err == nil {
t.Fatal("Got no error after getting config with startup true")
@@ -290,9 +277,7 @@ func TestConfigStartup(t *testing.T) {
}
testFalse := func() {
startup := false
- // This should succeed
- _, err := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, startup)
- // this should fail as we have not authorized yet/chosen profile
+ _, err := state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, startup)
if err != nil {
t.Fatalf("Got error after getting config with startup=false: %v", err)
}
@@ -302,8 +287,10 @@ func TestConfigStartup(t *testing.T) {
// set invalid authorization and test again
// we cannot test by setting invalid profile because the server only has 1 profile
- // TODO: support multiple profiles in the test server
- state.Servers.CustomServers.Map[serverURI].OAuth().SetTokenRenew()
+ err = state.tokCacher.Set(serverURI, srvtypes.TypeCustom, eduoauth.Token{})
+ if err != nil {
+ t.Fatalf("Failed to set token cache: %v", err)
+ }
testTrue()
testFalse()
}
@@ -319,7 +306,7 @@ func TestPreferTCP(t *testing.T) {
"0.1.0-test",
dir,
func(old FSMStateID, new FSMStateID, data interface{}) bool {
- stateCallback(t, &ck, old, new, data)
+ go stateCallback(ck, old, new, data)
return true
},
false,
@@ -332,13 +319,13 @@ func TestPreferTCP(t *testing.T) {
t.Fatalf("Registering error: %v", err)
}
- addErr := state.AddServer(&ck, serverURI, srvtypes.TypeCustom, false)
+ addErr := state.AddServer(ck, serverURI, srvtypes.TypeCustom, false)
if addErr != nil {
t.Fatalf("Add error: %v", addErr)
}
// get a config with preferTCP set to true
- config, configErr := state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, true, false)
+ config, configErr := state.GetConfig(ck, serverURI, srvtypes.TypeCustom, true, false)
// Test server should accept prefer TCP!
if config.Protocol != protocol.OpenVPN {
@@ -355,7 +342,7 @@ func TestPreferTCP(t *testing.T) {
}
// get a config with preferTCP set to false
- config, configErr = state.GetConfig(&ck, serverURI, srvtypes.TypeCustom, false, false)
+ config, configErr = state.GetConfig(ck, serverURI, srvtypes.TypeCustom, false, false)
if configErr != nil {
t.Fatalf("Config error: %v", configErr)
}
diff --git a/client/proxy.go b/client/proxy.go
new file mode 100644
index 0000000..0e78792
--- /dev/null
+++ b/client/proxy.go
@@ -0,0 +1,28 @@
+package client
+
+import (
+ "codeberg.org/eduVPN/proxyguard"
+ "github.com/eduvpn/eduvpn-common/i18nerr"
+ "github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/types/cookie"
+)
+
+type ProxyLogger struct{}
+
+func (pl *ProxyLogger) Logf(msg string, params ...interface{}) {
+ log.Logger.Debugf(msg, params...)
+}
+
+func (pl *ProxyLogger) Log(msg string) {
+ log.Logger.Debugf("%s", msg)
+}
+
+func (c *Client) StartProxyguard(ck *cookie.Cookie, listen string, tcpsp int, peer string) error {
+ var err error
+ proxyguard.UpdateLogger(&ProxyLogger{})
+ err = proxyguard.Client(ck.Context(), listen, tcpsp, peer, -1)
+ if err != nil {
+ return i18nerr.Wrap(err, "The VPN proxy exited")
+ }
+ return err
+}
diff --git a/client/token.go b/client/token.go
new file mode 100644
index 0000000..d62308b
--- /dev/null
+++ b/client/token.go
@@ -0,0 +1,64 @@
+package client
+
+import (
+ "errors"
+ "fmt"
+
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/jwijenbergh/eduoauth-go"
+)
+
+type cacheMap map[string]eduoauth.Token
+
+type TokenCacher struct {
+ InstituteAccess cacheMap
+ CustomServer cacheMap
+ SecureInternet *eduoauth.Token
+}
+
+func (c *cacheMap) Get(id string) (*eduoauth.Token, error) {
+ if c == nil || len(*c) == 0 {
+ return nil, errors.New("no cache map available")
+ }
+ if v, ok := (*c)[id]; ok {
+ return &v, nil
+ }
+ return nil, fmt.Errorf("identifier: '%s' does not exist in token cache map", id)
+}
+
+func (tc *TokenCacher) Get(id string, t srvtypes.Type) (*eduoauth.Token, error) {
+ switch t {
+ case srvtypes.TypeCustom:
+ return tc.CustomServer.Get(id)
+ case srvtypes.TypeInstituteAccess:
+ return tc.InstituteAccess.Get(id)
+ case srvtypes.TypeSecureInternet:
+ if tc.SecureInternet == nil {
+ return nil, errors.New("no secure internet server available")
+ }
+ return tc.SecureInternet, nil
+ }
+ return nil, fmt.Errorf("invalid type for token cacher get: %d", t)
+}
+
+func (c *cacheMap) Set(id string, t eduoauth.Token) {
+ if c == nil || len(*c) == 0 {
+ *c = make(cacheMap)
+ }
+ (*c)[id] = t
+}
+
+func (tc *TokenCacher) Set(id string, t srvtypes.Type, tok eduoauth.Token) error {
+ switch t {
+ case srvtypes.TypeCustom:
+ tc.CustomServer.Set(id, tok)
+ return nil
+ case srvtypes.TypeInstituteAccess:
+ tc.InstituteAccess.Set(id, tok)
+ return nil
+ case srvtypes.TypeSecureInternet:
+ tc.SecureInternet = &tok
+ return nil
+ }
+ return fmt.Errorf("invalid type for token cacher set: %d", t)
+}