diff options
Diffstat (limited to 'client/client.go')
| -rw-r--r-- | client/client.go | 622 |
1 files changed, 469 insertions, 153 deletions
diff --git a/client/client.go b/client/client.go index 70adb71..813f6dc 100644 --- a/client/client.go +++ b/client/client.go @@ -2,9 +2,9 @@ package client import ( + "context" "fmt" "strings" - "sync" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -14,32 +14,12 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" + "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type ( - // ServerBase is an alias to the internal ServerBase - // This contains the details for each server. - ServerBase = server.Base -) - -func (c *Client) logError(err error) { - // Logs the error with the same level/verbosity as the error - if c.Debug { - log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) - } else { - log.Logger.Inherit(err, "") - } -} - -func (c *Client) isLetsConnect() bool { - // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php - return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") -} - // isAllowedClientID checks if the 'clientID' is in the list of allowed client IDs func isAllowedClientID(clientID string) bool { allowList := []string{ @@ -91,13 +71,27 @@ func userAgentName(clientID string) string { } } +func (c *Client) logError(err error) { + // Logs the error with the same level/verbosity as the error + if c.Debug { + log.Logger.Inherit(err, fmt.Sprintf("\nwith stacktrace: %s\n", err.(*errors.Error).ErrorStack())) + } else { + log.Logger.Inherit(err, "") + } +} + +func (c *Client) isLetsConnect() bool { + // see https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php + return strings.HasPrefix(c.Name, "org.letsconnect-vpn.app") +} + // Client is the main struct for the VPN client. type Client struct { // The name of the client Name string `json:"-"` // The chosen server - Servers server.Servers `json:"servers"` + Servers server.List `json:"servers"` // The list of servers and organizations from disco Discovery discovery.Discovery `json:"discovery"` @@ -116,9 +110,6 @@ type Client struct { // The Failover monitor for the current VPN connection Failover *failover.DroppedConMon - - locationWg sync.WaitGroup - profileWg sync.WaitGroup } // New creates a new client with the following parameters: @@ -179,7 +170,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Registering means updating the FSM to get to the initial state correctly func (c *Client) Register() error { - if !c.InFSMState(StateDeregistered) { + if !c.FSM.InState(StateDeregistered) { return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current) } c.FSM.GoTransition(StateNoServer) @@ -200,27 +191,11 @@ func (c *Client) Deregister() { *c = Client{} } -// askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. -func (c *Client) askProfile(srv server.Server) error { - ps, err := server.ValidProfiles(srv, c.SupportsWireguard) - if err != nil { - return err - } - - c.profileWg.Add(1) - if err = c.FSM.GoTransitionRequired(StateAskProfile, convertProfiles(*ps)); err != nil { - return err - } - c.profileWg.Wait() - - return nil -} - // DiscoOrganizations gets the organizations list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#organization-list. -func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error) { +func (c *Client) DiscoOrganizations(ck *cookie.Cookie) (orgs *discotypes.Organizations, err error) { defer func() { if err != nil { c.logError(err) @@ -237,14 +212,15 @@ func (c *Client) DiscoOrganizations() (orgs *discotypes.Organizations, err error c.Discovery.MarkOrganizationsExpired() } - return c.Discovery.Organizations() + // TODO: pass a context + return c.Discovery.Organizations(ck.Context()) } // DiscoServers gets the servers list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. -func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { +func (c *Client) DiscoServers(ck *cookie.Cookie) (dss *discotypes.Servers, err error) { defer func() { if err != nil { c.logError(err) @@ -256,7 +232,8 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { return nil, errors.Errorf("discovery with Let's Connect is not supported") } - return c.Discovery.Servers() + // TODO: pass a context + return c.Discovery.Servers(ck.Context()) } // ExpiryTimes returns the different Unix timestamps regarding expiry @@ -266,7 +243,7 @@ func (c *Client) DiscoServers() (dss *discotypes.Servers, err error) { // These times are reset when the VPN gets disconnected func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { // Get current expiry time - srv, err := c.Servers.GetCurrentServer() + srv, err := c.Servers.Current() if err != nil { c.logError(err) return nil, err @@ -293,149 +270,488 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func convertProfiles(profiles server.ProfileInfo) srvtypes.Profiles { - m := make(map[string]srvtypes.Profile) - for _, p := range profiles.Info.ProfileList { - var protocols []protocol.Protocol - // loop through all protocol strings - for _, ps := range p.VPNProtoList { - protocols = append(protocols, protocol.New(ps)) - } - m[p.ID] = srvtypes.Profile{ - DisplayName: map[string]string{ - "en": p.DisplayName, - }, - Protocols: protocols, +func (c *Client) locationCallback(ck *cookie.Cookie) error { + locs := c.Discovery.SecureLocationList() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ + C: ck, + Data: locs, + }) + if err != nil { + errChan <- err } + }() + loc, err := ck.Receive(errChan) + if err != nil { + return err } - return srvtypes.Profiles{Map: m, Current: profiles.Current} + err = c.SetSecureLocation(ck, loc) + if err != nil { + return err + } + t := c.FSM.GoTransition(StateChosenLocation) + if !t { + log.Logger.Warningf("transition chosen location not completed") + } + return nil } -func convertGeneric(server server.InstituteAccessServer) (*srvtypes.Server, error) { - b, err := server.Base() +func (c *Client) loginCallback(ck *cookie.Cookie, srv server.Server) error { + url, err := server.OAuthURL(srv, c.Name) if err != nil { - return nil, err + return err } - return &srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - }, nil + err = c.FSM.GoTransitionRequired(StateOAuthStarted, url) + if err != nil { + return err + } + err = server.OAuthExchange(ck.Context(), srv) + if err != nil { + return err + } + return nil } -// TODO: CLEAN THIS UP -func (c *Client) ServerList() (*srvtypes.List, error) { - custom := c.Servers.CustomServers - var customServers []srvtypes.Server - for _, v := range custom.Map { - if v == nil { - return nil, errors.New("found nil value in custom server map") +func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) error { + // location + if srv.NeedsLocation() { + err := c.locationCallback(ck) + if err != nil { + return err } - conv, err := convertGeneric(*v) + } + + t := c.FSM.GoTransition(StateChosenServer) + if !t { + log.Logger.Warningf("transition not completed for chosen server") + } + // oauth + // TODO: This should be ck.Context() + // But needsrelogin needs a rewrite to support this properly + if server.NeedsRelogin(context.Background(), srv) || forceauth { + err := c.loginCallback(ck, srv) if err != nil { - return nil, errors.Errorf("failed to convert custom server for public type: %v", err) + return err } - customServers = append(customServers, *conv) } - institute := c.Servers.InstituteServers - var instituteServers []srvtypes.Institute - for _, v := range institute.Map { - if v == nil { - return nil, errors.New("found nil value in institute server map") + t = c.FSM.GoTransition(StateAuthorized) + if !t { + log.Logger.Warningf("transition authorized not completed") + } + + return nil +} + +func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error { + vp, err := server.HasValidProfile(ck.Context(), srv, c.SupportsWireguard) + if err != nil { + return err + } + if !vp { + b, err := srv.Base() + if err != nil { + return err } - conv, err := convertGeneric(*v) + ps := b.Profiles.Public() + errChan := make(chan error) + go func() { + err := c.FSM.GoTransitionRequired(StateAskProfile, &srvtypes.RequiredAskTransition{ + C: ck, + Data: ps, + }) + if err != nil { + errChan <- err + } + }() + pID, err := ck.Receive(errChan) if err != nil { - return nil, errors.Errorf("failed to convert institute server for public type: %v", err) + return err + } + err = server.Profile(srv, pID) + if err != nil { + return err } - instituteServers = append(instituteServers, srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }) } + t := c.FSM.GoTransition(StateChosenProfile) + if !t { + log.Logger.Warningf("transition chosen profile not completed") + } + return nil +} - var secureInternet *srvtypes.SecureInternet - if c.Servers.HasSecureInternet() { - b, err := c.Servers.SecureInternetHomeServer.Base() - if err == nil { - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: b.URL, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - secureInternet = &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, - // TODO: delisted - Delisted: false, - } +// AddServer adds a server with identifier and type +func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) { + // If we have failed to add the server, we remove it again + // We add the server because we can then obtain it in other callback functions + defer func() { + if err != nil { + _ = c.RemoveServer(identifier, _type) //nolint:errcheck } + c.FSM.GoTransition(StateNoServer) + }() + if !ni { + if !c.FSM.InState(StateNoServer) { + return errors.Errorf("wrong state to add a server: %s", GetStateName(c.FSM.Current)) + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } } - return &srvtypes.List{ - Institutes: instituteServers, - SecureInternet: secureInternet, - Custom: customServers, - }, nil -} -// TODO: CLEAN THIS UP -func (c *Client) CurrentServer() (*srvtypes.Current, error) { - srvs := c.Servers + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + + var srv server.Server - switch srvs.IsType { - case server.InstituteAccessServerType: - curr, err := srvs.GetInstituteAccess(srvs.InstituteServers.CurrentURL) + switch _type { + case srvtypes.TypeInstituteAccess: + dSrv, err := c.Discovery.ServerByURL(identifier, "institute_access") if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddInstituteAccess(ck.Context(), dSrv) if err != nil { - return nil, err + return err } - return &srvtypes.Current{ - Institute: &srvtypes.Institute{ - Server: *conv, - // TODO: delisted - Delisted: false, - }, - Type: srvtypes.TypeInstituteAccess, - }, nil - case server.CustomServerType: - curr, err := srvs.GetCustomServer(srvs.CustomServers.CurrentURL) + case srvtypes.TypeSecureInternet: + dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { - return nil, err + return err } - conv, err := convertGeneric(*curr) + srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) if err != nil { - return nil, err + return err + } + case srvtypes.TypeCustom: + srv, err = c.Servers.AddCustom(ck.Context(), identifier) + if err != nil { + return err + } + default: + return errors.Errorf("not a valid server type: %v", _type) + } + + // if we are non interactive, we run no callbacks + if ni { + return nil + } + + // callbacks + return c.callbacks(ck, srv, false) +} + +func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { + // do the callbacks to ensure valid profile, location and authorization + err = c.callbacks(ck, srv, forceAuth) + if err != nil { + return nil, err + } + + t := c.FSM.GoTransition(StateRequestConfig) + if !t { + log.Logger.Warningf("transition not completed for requesting config") + } + + err = c.profileCallback(ck, srv) + if err != nil { + return nil, err + } + + cfgS, err := server.Config(ck.Context(), srv, c.SupportsWireguard, pTCP) + if err != nil { + return nil, err + } + p, err := server.CurrentProfile(srv) + if err != nil { + return nil, err + } + pcfg := cfgS.Public(p.DefaultGateway) + if err != nil { + return nil, err + } + return &pcfg, nil +} + +func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Server, setter func(server.Server) error, err error) { + switch _type { + case srvtypes.TypeInstituteAccess: + srv, err = c.Servers.InstituteAccess(identifier) + setter = c.Servers.SetInstituteAccess + case srvtypes.TypeSecureInternet: + srv, err = c.Servers.SecureInternet(identifier) + setter = c.Servers.SetSecureInternet + case srvtypes.TypeCustom: + srv, err = c.Servers.CustomServer(identifier) + setter = c.Servers.SetCustom + default: + return nil, nil, errors.Errorf("not a valid server type: %v", _type) + } + return srv, setter, err +} + +// GetConfig gets a VPN configuration +func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) { + defer func() { + if err == nil { + c.FSM.GoTransition(StateGotConfig) + } else { + // go back if an error occurred + c.FSM.GoTransition(StateNoServer) + } + }() + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return nil, err + } + t := c.FSM.GoTransition(StateLoadingServer) + if !t { + log.Logger.Warningf("transition not completed for loading server") + } + srv, set, err := c.server(identifier, _type) + if err != nil { + return nil, err + } + // refresh the server endpoints + err = server.RefreshEndpoints(ck.Context(), srv) + if err != nil { + log.Logger.Warningf("failed to refresh server endpoints: %v", err) + } + + // get a config and retry with authorization if expired + cfg, err = c.config(ck, srv, pTCP, false) + tErr := &oauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + cfg, err = c.config(ck, srv, pTCP, true) + } + + // still an error, return nil with the error + if err != nil { + return nil, err + } + + // set the current server + if err = set(srv); err != nil { + return nil, err + } + + return cfg, nil +} + +func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error) { + identifier, err = http.EnsureValidURL(identifier, _type != srvtypes.TypeSecureInternet) + if err != nil { + return err + } + switch _type { + case srvtypes.TypeInstituteAccess: + return c.Servers.RemoveInstituteAccess(identifier) + case srvtypes.TypeSecureInternet: + return c.Servers.RemoveSecureInternet(identifier) + case srvtypes.TypeCustom: + return c.Servers.RemoveCustom(identifier) + default: + return errors.Errorf("not a valid server type: %v", _type) + } +} + +func (c *Client) CurrentServer() (*srvtypes.Current, error) { + if !c.FSM.InState(StateGotConfig) { + return nil, errors.Errorf("State: %s, cannot have a current server. Did you get a VPN configuration?", GetStateName(c.FSM.Current)) + } + srv, err := c.Servers.Current() + if err != nil { + return nil, err + } + return c.pubCurrentServer(srv) +} + +func (c *Client) pubCurrentServer(srv server.Server) (*srvtypes.Current, error) { + b, err := srv.Base() + if err != nil { + return nil, err + } + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Current{ + Institute: &srvtypes.Institute{ + Server: *t, + // TODO: delisted + Delisted: false, + }, + Type: srvtypes.TypeInstituteAccess, + }, nil } return &srvtypes.Current{ - Custom: conv, + Custom: t, Type: srvtypes.TypeCustom, }, nil - case server.SecureInternetServerType: - b, err := c.Servers.SecureInternetHomeServer.Base() + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return &srvtypes.Current{ + SecureInternet: t, + Type: srvtypes.TypeSecureInternet, + }, nil + default: + panic("unknown type") + } +} + +// TODO: This should not rely on interface{} +func (c *Client) pubServer(srv server.Server) (interface{}, error) { + pub, err := srv.Public() + if err != nil { + return nil, err + } + switch t := pub.(type) { + case *srvtypes.Server: + b, err := srv.Base() if err != nil { return nil, err } - generic := srvtypes.Server{ - DisplayName: b.DisplayName, - Identifier: c.Servers.SecureInternetHomeServer.HomeOrganizationID, - Profiles: convertProfiles(b.Profiles), - } - cc := c.Servers.SecureInternetHomeServer.CurrentLocation - return &srvtypes.Current{ - SecureInternet: &srvtypes.SecureInternet{ - Server: generic, - CountryCode: cc, + if b.Type == srvtypes.TypeInstituteAccess { + return &srvtypes.Institute{ + Server: *t, // TODO: delisted Delisted: false, - }, - Type: srvtypes.TypeSecureInternet, - }, nil + }, nil + } + return t, nil + case *srvtypes.SecureInternet: + t.Locations = c.Discovery.SecureLocationList() + return t, nil default: - return nil, errors.New("current server not found") + panic("unknown type") + } +} + +func (c *Client) ServerList() (*srvtypes.List, error) { + if c.FSM.InState(StateDeregistered) { + return nil, errors.New("client is not registered") + } + var customServers []srvtypes.Server + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + c, ok := p.(*srvtypes.Server) + if !ok { + continue + } + customServers = append(customServers, *c) } + var instituteServers []srvtypes.Institute + for _, v := range c.Servers.CustomServers.Map { + if v == nil { + continue + } + p, err := c.pubServer(v) + if err != nil { + continue + } + i, ok := p.(*srvtypes.Institute) + if !ok { + continue + } + instituteServers = append(instituteServers, *i) + } + var secureInternet *srvtypes.SecureInternet + if c.Servers.HasSecureInternet() { + srv := &c.Servers.SecureInternetHomeServer + p, err := c.pubServer(srv) + if err == nil { + s, ok := p.(*srvtypes.SecureInternet) + if ok { + secureInternet = s + } + } + } + return &srvtypes.List{ + Institutes: instituteServers, + SecureInternet: secureInternet, + Custom: customServers, + }, nil +} + +func (c *Client) SetProfileID(pID string) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + return server.Profile(srv, pID) +} + +func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { + // get the current server + srv, err := c.Servers.Current() + if err != nil { + return err + } + // TODO: Support cookie context here + // if server.NeedsRelogin(context.Background(), srv) { + // // TODO: ask client for tokens + // } + err = server.Disconnect(ck.Context(), srv) + if err != nil { + return err + } + // TODO: Set tokens with callback + return nil +} + +func (c *Client) SetSecureLocation(ck *cookie.Cookie, countryCode string) (err error) { + if c.isLetsConnect() { + return errors.Errorf("setting a secure internet location with Let's Connect! is not supported") + } + + if !c.Servers.HasSecureInternet() { + return errors.Errorf("no secure internet server available to set a location for") + } + + dSrv, err := c.Discovery.ServerByCountryCode(countryCode) + if err != nil { + return err + } + + return c.Servers.SecureInternetHomeServer.Location(ck.Context(), dSrv) +} + +func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { + srv, err := c.Servers.Current() + if err != nil { + return err + } + // The server has not been chosen yet, this means that we want to manually renew + // TODO: is this needed? + if !c.FSM.InState(StateChosenServer) { + c.FSM.GoTransition(StateLoadingServer) + c.FSM.GoTransition(StateChosenServer) + } + // TODO: Maybe this can be deleted because we force auth now + server.MarkTokensForRenew(srv) + // run the callbacks by forcing auth + return c.callbacks(ck, srv, true) +} + +func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, wgMTU int, readRxBytes func() (int64, error)) (bool, error) { + if c.Failover != nil { + return false, errors.New("another failover process is already started") + } + + c.Failover = failover.New(readRxBytes) + + return c.Failover.Start(ck.Context(), gateway, wgMTU) } |
