diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-04-12 22:55:16 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | a38e3e79f74e95051db7e14ae14ab817b68b725a (patch) | |
| tree | e26cab53f993d2d845020f81ee6f6f6a8a12ded1 /client/client.go | |
| parent | 4d228ba2084eb810d0cc33308893b00d1bb3eb02 (diff) | |
Refactor: Move client implementation to one file
Much easier to oversee and it forces me to keep the client type as
small as possible. This also uses the cookie for cancellation
We also no longer require tokens inside arguments. We will later
implement them with callbacks
Diffstat (limited to 'client/client.go')
| -rw-r--r-- | client/client.go | 622 |
1 files changed, 469 insertions, 153 deletions
diff --git a/client/client.go b/client/client.go index 70adb71..813f6dc 100644 --- a/client/client.go +++ b/client/client.go @@ -2,9 +2,9 @@ package client import ( + "context" "fmt" "strings" - "sync" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -14,32 +14,12 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" + "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type ( - // ServerBase is an alias to the internal ServerBase - // This contains the details for each server. - ServerBase = server.Base -) - -func (c *Client) logError(err error) { - // Logs the error with the same level/verbosity as the error - if c.Debug { - log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) - } else { - log.Logger.Inherit(err, "") - } -} - -func (c *Client) isLetsConnect() bool { - // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php - return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") -} - // isAllowedClientID checks if the 'clientID' is in the list of allowed client IDs func isAllowedClientID(clientID string) bool { allowList := []string{ @@ -91,13 +71,27 @@ func userAgentName(clientID string) string { } } +func (c *Client) logError(err error) { + // Logs the error with the same level/verbosity as the error + if c.Debug { + log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) + } else { + log.Logger.Inherit(err, "") + } +} + +func (c *Client) isLetsConnect() bool { + // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php + return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") +} + // Client is the main struct for the VPN client. type Client struct { // The name of the client Name string `json:"-"` // The chosen server - Servers server.Servers `json:"servers"` + Servers server.List `json:"servers"` // The list of servers and organizations from disco Discovery discovery.Discovery `json:"discovery"` @@ -116,9 +110,6 @@ type Client struct { // The Failover monitor for the current VPN connection Failover *failover.DroppedConMon - - locationWg sync.WaitGroup - profileWg sync.WaitGroup } // New creates a new client with the following parameters: @@ -179,7 +170,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Registering means updating the FSM to get to the initial state correctly func (c *Client) Register() error { - if !c.InFSMState(StateDeregistered) { + if !c.FSM.InState(StateDeregistered) { return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) } c.FSM.GoTransition(StateNoServer) @@ -200,27 +191,11 @@ func (c *Client) Deregister() { *c = Client{} } -// askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. -func (c *Client) askProfile(srv server.Server) error { - ps, err := server.ValidProfiles(srv, c.SupportsWireguard) - if err != nil { - return err - } - - c.profileWg.Add(1) - if err = c.FSM.GoTransitionRequired(StateAskProfile, convertProfiles(*ps)); err != nil { - return err - } - c.profileWg.Wait() - - return nil -} - // DiscoOrganizations gets the organizations list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list. -func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error) { +func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) { defer func() { if err != nil { c.logError(err) @@ -237,14 +212,15 @@ func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error c.Discovery.MarkOrganizationsExpired() } - return c.Discovery.Organizations() + // TODO: pass a context + return c.Discovery.Organizations(ck.Context()) } // DiscoServers gets the servers list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. -func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { +func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) { defer func() { if err != nil { c.logError(err) @@ -256,7 +232,8 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { return nil, errors.Errorf("discovery with Let's Connect is not supported") } - return c.Discovery.Servers() + // TODO: pass a context + return c.Discovery.Servers(ck.Context()) } // ExpiryTimes returns the different Unix timestamps regarding expiry @@ -266,7 +243,7 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { // These times are reset when the VPN gets disconnected func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { // Get current expiry time - srv, err := c.Servers.GetCurrentServer() + srv, err := c.Servers.Current() if err != nil { c.logError(err) return nil, err @@ -293,149 +270,488 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func convertProfiles(profiles server.ProfileInfo) srvtypes.Profiles { - m := make(map[string]srvtypes.Profile) - for _, p := range profiles.Info.ProfileList { - var protocols []protocol.Protocol - // loop through all protocol strings - for _, ps := range p.VPNProtoList { - protocols = append(protocols, protocol.New(ps)) - } - m[p.ID] = srvtypes.Profile{ - DisplayName: map[string]string{ - "en": p.DisplayName, - }, - Protocols: protocols, +func (c *Client) locationCallback(ck *cookie.Cookie) error { + locs := c.Discovery.SecureLocationList() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ + C: ck, + Data: locs, + }) + if err != nil { + errChan <- err } + }() + loc, err := ck.Receive(errChan) + if err != nil { + return err } - return srvtypes.Profiles{Map: m, Current: profiles.Current} + err = c.SetSecureLocation(ck, loc) + if err != nil { + return err + } + t := c.FSM.GoTransition(StateChosenLocation) + if !t { + log.Logger.Warningf("transition chosen location not completed") + } + return nil } -func convertGeneric(server server.InstituteAccessServer) (*srvtypes.Server, error) { - b, err := server.Base() +func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error { + url, err := server.OAuthURL(srv, c.Name) if err != nil { - return nil, err + return err } - return &srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - }, nil + err = c.FSM.GoTransitionRequired(StateOAuthStarted, url) + if err != nil { + return err + } + err = server.OAuthExchange(ck.Context(), srv) + if err != nil { + return err + } + return nil } -// TODO: CLEAN THIS UP -func (c *Client) ServerList() (*srvtypes.List, error) { - custom := c.Servers.CustomServers - var customServers []srvtypes.Server - for _, v := range custom.Map { - if v == nil { - return nil, errors.New("found nil value in custom server map") +func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) error { + // location + if srv.NeedsLocation() { + err := c.locationCallback(ck) + if err != nil { + return err } - conv, err := convertGeneric(*v) + } + + t := c.FSM.GoTransition(StateChosenServer) + if !t { + log.Logger.Warningf("transition not completed for chosen server") + } + // oauth + // TODO: This should be ck.Context() + // But needsrelogin needs a rewrite to support this properly + if server.NeedsRelogin(context.Background(), srv) || forceauth { + err := c.loginCallback(ck, srv) if err != nil { - return nil, errors.Errorf("failed to convert custom server for public type: %v", err) + return err } - customServers = append(customServers, *conv) } - institute := c.Servers.InstituteServers - var instituteServers []srvtypes.Institute - for _, v := range institute.Map { - if v == nil { - return nil, errors.New("found nil value in institute server map") + t = c.FSM.GoTransition(StateAuthorized) + if !t { + log.Logger.Warningf("transition authorized not completed") + } + + return nil +} + +func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error { + vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard) + if err != nil { + return err + } + if !vp { + b, err := srv.Base() + if err != nil { + return err } - conv, err := convertGeneric(*v) + ps := b.Profiles.Public() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ + C: ck, + Data: ps, + }) + if err != nil { + errChan <- err + } + }() + pID, err := ck.Receive(errChan) if err != nil { - return nil, errors.Errorf("failed to convert institute server for public type: %v", err) + return err + } + err = server.Profile(srv, pID) + if err != nil { + return err } - instituteServers = append(instituteServers, srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }) } + t := c.FSM.GoTransition(StateChosenProfile) + if !t { + log.Logger.Warningf("transition chosen profile not completed") + } + return nil +} - var secureInternet *srvtypes.SecureInternet - if c.Servers.HasSecureInternet() { - b, err := c.Servers.SecureInternetHomeServer.Base() - if err == nil { - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - secureInternet = &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, - // TODO: delisted - Delisted: false, - } +// AddServer adds a server with identifier and type +func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) { + // If we have failed to add the server, we remove it again + // We add the server because we can then obtain it in other callback functions + defer func() { + if err != nil { + _ = c.RemoveServer(identifier, _type) //nolint:errcheck } + c.FSM.GoTransition(StateNoServer) + }() + if !ni { + if !c.FSM.InState(StateNoServer) { + return errors.Errorf("wrong state to add a server: %s", GetStateName(c.FSM.Current)) + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } } - return &srvtypes.List{ - Institutes: instituteServers, - SecureInternet: secureInternet, - Custom: customServers, - }, nil -} -// TODO: CLEAN THIS UP -func (c *Client) CurrentServer() (*srvtypes.Current, error) { - srvs := c.Servers + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + + var srv server.Server - switch srvs.IsType { - case server.InstituteAccessServerType: - curr, err := srvs.GetInstituteAccess(srvs.InstituteServers.CurrentURL) + switch _type { + case srvtypes.TypeInstituteAccess: + dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access") if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddInstituteAccess(ck.Context(), dSrv) if err != nil { - return nil, err + return err } - return &srvtypes.Current{ - Institute: &srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }, - Type: srvtypes.TypeInstituteAccess, - }, nil - case server.CustomServerType: - curr, err := srvs.GetCustomServer(srvs.CustomServers.CurrentURL) + case srvtypes.TypeSecureInternet: + dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) if err != nil { - return nil, err + return err + } + case srvtypes.TypeCustom: + srv, err = c.Servers.AddCustom(ck.Context(), identifier) + if err != nil { + return err + } + default: + return errors.Errorf("not a valid server type: %v", _type) + } + + // if we are non interactive, we run no callbacks + if ni { + return nil + } + + // callbacks + return c.callbacks(ck, srv, false) +} + +func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { + // do the callbacks to ensure valid profile, location and authorization + err = c.callbacks(ck, srv, forceAuth) + if err != nil { + return nil, err + } + + t := c.FSM.GoTransition(StateRequestConfig) + if !t { + log.Logger.Warningf("transition not completed for requesting config") + } + + err = c.profileCallback(ck, srv) + if err != nil { + return nil, err + } + + cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP) + if err != nil { + return nil, err + } + p, err := server.CurrentProfile(srv) + if err != nil { + return nil, err + } + pcfg := cfgS.Public(p.DefaultGateway) + if err != nil { + return nil, err + } + return &pcfg, nil +} + +func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) { + switch _type { + case srvtypes.TypeInstituteAccess: + srv, err = c.Servers.InstituteAccess(identifier) + setter = c.Servers.SetInstituteAccess + case srvtypes.TypeSecureInternet: + srv, err = c.Servers.SecureInternet(identifier) + setter = c.Servers.SetSecureInternet + case srvtypes.TypeCustom: + srv, err = c.Servers.CustomServer(identifier) + setter = c.Servers.SetCustom + default: + return nil, nil, errors.Errorf("not a valid server type: %v", _type) + } + return srv, setter, err +} + +// GetConfig gets a VPN configuration +func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) { + defer func() { + if err == nil { + c.FSM.GoTransition(StateGotConfig) + } else { + // go back if an error occurred + c.FSM.GoTransition(StateNoServer) + } + }() + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return nil, err + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } + srv, set, err := c.server(identifier, _type) + if err != nil { + return nil, err + } + // refresh the server endpoints + err = server.RefreshEndpoints(ck.Context(), srv) + if err != nil { + log.Logger.Warningf("failed to refresh server endpoints: %v", err) + } + + // get a config and retry with authorization if expired + cfg, err = c.config(ck, srv, pTCP, false) + tErr := &oauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + cfg, err = c.config(ck, srv, pTCP, true) + } + + // still an error, return nil with the error + if err != nil { + return nil, err + } + + // set the current server + if err = set(srv); err != nil { + return nil, err + } + + return cfg, nil +} + +func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + switch _type { + case srvtypes.TypeInstituteAccess: + return c.Servers.RemoveInstituteAccess(identifier) + case srvtypes.TypeSecureInternet: + return c.Servers.RemoveSecureInternet(identifier) + case srvtypes.TypeCustom: + return c.Servers.RemoveCustom(identifier) + default: + return errors.Errorf("not a valid server type: %v", _type) + } +} + +func (c *Client) CurrentServer() (*srvtypes.Current, error) { + if !c.FSM.InState(StateGotConfig) { + return nil, errors.Errorf("State: %s, cannot have a current server. Did you get a VPN configuration?", GetStateName(c.FSM.Current)) + } + srv, err := c.Servers.Current() + if err != nil { + return nil, err + } + return c.pubCurrentServer(srv) +} + +func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) { + b, err := srv.Base() + if err != nil { + return nil, err + } + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Current{ + Institute: &srvtypes.Institute{ + Server: *t, + // TODO: delisted + Delisted: false, + }, + Type: srvtypes.TypeInstituteAccess, + }, nil } return &srvtypes.Current{ - Custom: conv, + Custom: t, Type: srvtypes.TypeCustom, }, nil - case server.SecureInternetServerType: - b, err := c.Servers.SecureInternetHomeServer.Base() + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return &srvtypes.Current{ + SecureInternet: t, + Type: srvtypes.TypeSecureInternet, + }, nil + default: + panic("unknown type") + } +} + +// TODO: This should not rely on interface{} +func (c *Client) pubServer(srv server.Server) (interface{}, error) { + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + b, err := srv.Base() if err != nil { return nil, err } - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: c.Servers.SecureInternetHomeServer.HomeOrganizationID, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - return &srvtypes.Current{ - SecureInternet: &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Institute{ + Server: *t, // TODO: delisted Delisted: false, - }, - Type: srvtypes.TypeSecureInternet, - }, nil + }, nil + } + return t, nil + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return t, nil default: - return nil, errors.New("current server not found") + panic("unknown type") + } +} + +func (c *Client) ServerList() (*srvtypes.List, error) { + if c.FSM.InState(StateDeregistered) { + return nil, errors.New("client is not registered") + } + var customServers []srvtypes.Server + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + c, ok := p.(*srvtypes.Server) + if !ok { + continue + } + customServers = append(customServers, *c) } + var instituteServers []srvtypes.Institute + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + i, ok := p.(*srvtypes.Institute) + if !ok { + continue + } + instituteServers = append(instituteServers, *i) + } + var secureInternet *srvtypes.SecureInternet + if c.Servers.HasSecureInternet() { + srv := &c.Servers.SecureInternetHomeServer + p, err := c.pubServer(srv) + if err == nil { + s, ok := p.(*srvtypes.SecureInternet) + if ok { + secureInternet = s + } + } + } + return &srvtypes.List{ + Institutes: instituteServers, + SecureInternet: secureInternet, + Custom: customServers, + }, nil +} + +func (c *Client) SetProfileID(pID string) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + return server.Profile(srv, pID) +} + +func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { + // get the current server + srv, err := c.Servers.Current() + if err != nil { + return err + } + // TODO: Support cookie context here + // if server.NeedsRelogin(context.Background(), srv) { + // // TODO: ask client for tokens + // } + err = server.Disconnect(ck.Context(), srv) + if err != nil { + return err + } + // TODO: Set tokens with callback + return nil +} + +func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) { + if c.isLetsConnect() { + return errors.Errorf("setting a secure internet location with Let's Connect! is not supported") + } + + if !c.Servers.HasSecureInternet() { + return errors.Errorf("no secure internet server available to set a location for") + } + + dSrv, err := c.Discovery.ServerByCountryCode(countryCode) + if err != nil { + return err + } + + return c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv) +} + +func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + // The server has not been chosen yet, this means that we want to manually renew + // TODO: is this needed? + if !c.FSM.InState(StateChosenServer) { + c.FSM.GoTransition(StateLoadingServer) + c.FSM.GoTransition(StateChosenServer) + } + // TODO: Maybe this can be deleted because we force auth now + server.MarkTokensForRenew(srv) + // run the callbacks by forcing auth + return c.callbacks(ck, srv, true) +} + +func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { + if c.Failover != nil { + return false, errors.New("another failover process is already started") + } + + c.Failover = failover.New(readRxBytes) + + return c.Failover.Start(ck.Context(), gateway, wgMTU) } |
