diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-10-11 10:19:56 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-10-11 12:42:11 +0200 |
| commit | 565237c14a303a46d62d240b35c6f0082424256a (patch) | |
| tree | 522f2aeb441a3eb22b6d5e05e66ef348241b2e66 /client.go | |
| parent | 17e261dd224bc67f031b80930490768ea54353db (diff) | |
Client: Refactor out adding a Server from getting a config
Diffstat (limited to 'client.go')
| -rw-r--r-- | client.go | 306 |
1 files changed, 179 insertions, 127 deletions
@@ -11,8 +11,8 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" - "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" + "github.com/eduvpn/eduvpn-common/types" ) type ( @@ -134,7 +134,7 @@ func (client *Client) Deregister() { client.Logger.Info( fmt.Sprintf( "Failed saving configuration, error: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -150,7 +150,7 @@ func (client *Client) goBackInternal() { client.Logger.Info( fmt.Sprintf( "Failed going back, error: %s", - GetErrorTraceback(goBackErr), + types.GetErrorTraceback(goBackErr), ), ) } @@ -179,7 +179,7 @@ func (client *Client) ensureLogin(chosenServer server.Server) error { // Relogin with oauth // This moves the state to authorized if server.NeedsRelogin(chosenServer) { - url, urlErr := server.GetOAuthURL(chosenServer, client.FSM.Name) + url, urlErr := server.GetOAuthURL(chosenServer, client.Name) client.FSM.GoTransitionWithData(STATE_OAUTH_STARTED, url, true) @@ -277,7 +277,7 @@ func (client *Client) getConfig( config, configType, configErr := client.retryConfigAuth(chosenServer, preferTCP) if configErr != nil { - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } currentServer, currentServerErr := client.Servers.GetCurrentServer() @@ -294,7 +294,7 @@ func (client *Client) getConfig( client.Logger.Info( fmt.Sprintf( "Failed saving configuration after getting a server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -318,7 +318,7 @@ func (client *Client) SetSecureLocation(countryCode string) error { fmt.Sprintf( "Failed getting secure internet server by country code: %s with error: %s", countryCode, - GetErrorTraceback(serverErr), + types.GetErrorTraceback(serverErr), ), ) client.goBackInternal() @@ -330,7 +330,7 @@ func (client *Client) SetSecureLocation(countryCode string) error { client.Logger.Error( fmt.Sprintf( "Failed setting secure internet server with error: %s", - GetErrorTraceback(setLocationErr), + types.GetErrorTraceback(setLocationErr), ), ) client.goBackInternal() @@ -367,42 +367,6 @@ func (client *Client) askSecureLocation() error { return nil } -// addSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file. -// Because there is only one Secure Internet Home Server, it replaces the existing one. -func (client *Client) addSecureInternetHomeServer(orgID string) (server.Server, error) { - errorMessage := fmt.Sprintf( - "failed adding Secure Internet home server with organization ID %s", - orgID, - ) - // Get the secure internet URL from discovery - secureOrg, secureServer, discoErr := client.Discovery.GetSecureHomeArgs(orgID) - if discoErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} - } - - // Add the secure internet server - server, serverErr := client.Servers.AddSecureInternet(secureOrg, secureServer) - - if serverErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} - } - - var locationErr error - - if !client.Servers.HasSecureLocation() { - locationErr = client.askSecureLocation() - } else { - // reinitialize - locationErr = client.SetSecureLocation(client.Servers.GetSecureLocation()) - } - - if locationErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: locationErr} - } - - return server, nil -} - // RemoveSecureInternet removes the current secure internet server. // It returns an error if the server cannot be removed due to the state being DEREGISTERED. // Note that if the server does not exist, it returns nil as an error. @@ -423,7 +387,7 @@ func (client *Client) RemoveSecureInternet() error { client.Logger.Info( fmt.Sprintf( "Failed saving configuration after removing a secure internet server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -449,7 +413,7 @@ func (client *Client) RemoveInstituteAccess(url string) error { client.Logger.Info( fmt.Sprintf( "Failed saving configuration after removing an institute access server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -475,79 +439,108 @@ func (client *Client) RemoveCustomServer(url string) error { client.Logger.Info( fmt.Sprintf( "Failed saving configuration after removing a custom server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } return nil } -// GetConfigSecureInternet gets a configuration for a Secure Internet Server. -// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (client *Client) GetConfigSecureInternet( - orgID string, - preferTCP bool, -) (string, string, error) { - errorMessage := fmt.Sprintf( - "failed getting a configuration for Secure Internet organization %s", - orgID, - ) +// AddInstituteServer adds an Institute Access server by `url`. +func (client *Client) AddInstituteServer(url string) (server.Server, error) { + errorMessage := fmt.Sprintf("failed adding Institute Access server with url %s", url) // Not supported with Let's Connect! if client.isLetsConnect() { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} } + // Indicate that we're loading the server client.FSM.GoTransition(STATE_LOADING_SERVER) - server, serverErr := client.addSecureInternetHomeServer(orgID) + + // FIXME: Do nothing with discovery here as the client already has it + // So pass a server as the parameter + instituteServer, discoErr := client.Discovery.GetServerByURL(url, "institute_access") + if discoErr != nil { + client.goBackInternal() + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} + } + + // Add the secure internet server + server, serverErr := client.Servers.AddInstituteAccessServer(instituteServer) if serverErr != nil { - client.Logger.Error( - fmt.Sprintf( - "Failed adding a secure internet server with error: %s", - GetErrorTraceback(serverErr), - ), - ) client.goBackInternal() - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + // Indicate that we want to authorize this server client.FSM.GoTransition(STATE_CHOSEN_SERVER) - config, configType, configErr := client.getConfig(server, preferTCP) - if configErr != nil { - client.Logger.Inherit( - configErr, - fmt.Sprintf( - "Failed getting a secure internet configuration with error: %s", - GetErrorTraceback(configErr), - ), - ) - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + // Authorize it + loginErr := client.ensureLogin(server) + if loginErr != nil { + // Removing is best effort + _ = client.RemoveInstituteAccess(url) + return nil, &types.WrappedErrorMessage{Level: types.GetErrorLevel(loginErr), Message: errorMessage, Err: loginErr} } - return config, configType, nil + + client.FSM.GoTransitionWithData(STATE_NO_SERVER, client.Servers, false) + return server, nil } -// addInstituteServer adds an Institute Access server by `url`. -func (client *Client) addInstituteServer(url string) (server.Server, error) { - errorMessage := fmt.Sprintf("failed adding Institute Access server with url %s", url) - instituteServer, discoErr := client.Discovery.GetServerByURL(url, "institute_access") +// AddSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file. +// Because there is only one Secure Internet Home Server, it replaces the existing one. +func (client *Client) AddSecureInternetHomeServer(orgID string) (server.Server, error) { + errorMessage := fmt.Sprintf( + "failed adding Secure Internet home server with organization ID %s", + orgID, + ) + + // Not supported with Let's Connect! + if client.isLetsConnect() { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} + } + + // Indicate that we're loading the server + client.FSM.GoTransition(STATE_LOADING_SERVER) + + // Get the secure internet URL from discovery + secureOrg, secureServer, discoErr := client.Discovery.GetSecureHomeArgs(orgID) if discoErr != nil { + client.goBackInternal() return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} } + // Add the secure internet server - server, serverErr := client.Servers.AddInstituteAccessServer(instituteServer) + server, serverErr := client.Servers.AddSecureInternet(secureOrg, secureServer) if serverErr != nil { + client.goBackInternal() return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + locationErr := client.askSecureLocation() + if locationErr != nil { + // Removing is best effort + _ = client.RemoveSecureInternet() + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: locationErr} + } + + // Server has been chosen for authentication client.FSM.GoTransition(STATE_CHOSEN_SERVER) + // Authorize it + loginErr := client.ensureLogin(server) + if loginErr != nil { + // Removing is best effort + _ = client.RemoveSecureInternet() + return nil, &types.WrappedErrorMessage{Level: types.GetErrorLevel(loginErr), Message: errorMessage, Err: loginErr} + } + client.FSM.GoTransitionWithData(STATE_NO_SERVER, client.Servers, false) return server, nil } -// addCustomServer adds a Custom Server by `url` -func (client *Client) addCustomServer(url string) (server.Server, error) { +// AddCustomServer adds a Custom Server by `url` +func (client *Client) AddCustomServer(url string) (server.Server, error) { errorMessage := fmt.Sprintf("failed adding Custom server with url %s", url) url, urlErr := util.EnsureValidURL(url) @@ -555,6 +548,9 @@ func (client *Client) addCustomServer(url string) (server.Server, error) { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} } + // Indicate that we're loading the server + client.FSM.GoTransition(STATE_LOADING_SERVER) + customServer := &types.DiscoveryServer{ BaseURL: url, DisplayName: map[string]string{"en": url}, @@ -564,11 +560,22 @@ func (client *Client) addCustomServer(url string) (server.Server, error) { // A custom server is just an institute access server under the hood server, serverErr := client.Servers.AddCustomServer(customServer) if serverErr != nil { + client.goBackInternal() return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + // Server has been chosen for authentication client.FSM.GoTransition(STATE_CHOSEN_SERVER) + // Authorize it + loginErr := client.ensureLogin(server) + if loginErr != nil { + // removing is best effort + _ = client.RemoveCustomServer(url) + return nil, &types.WrappedErrorMessage{Level: types.GetErrorLevel(loginErr), Message: errorMessage, Err: loginErr} + } + + client.FSM.GoTransitionWithData(STATE_NO_SERVER, client.Servers, false) return server, nil } @@ -584,60 +591,120 @@ func (client *Client) GetConfigInstituteAccess(url string, preferTCP bool) (stri } client.FSM.GoTransition(STATE_LOADING_SERVER) - server, serverErr := client.addInstituteServer(url) + + // Get the server if it exists + server, serverErr := client.Servers.GetInstituteAccess(url) if serverErr != nil { client.Logger.Error( fmt.Sprintf( - "Failed adding an institute access server with error: %s", - GetErrorTraceback(serverErr), + "Failed getting an institute access server configuration with error: %s", + types.GetErrorTraceback(serverErr), ), ) - client.goBackInternal() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + // The server has now been chosen + client.FSM.GoTransition(STATE_CHOSEN_SERVER) + config, configType, configErr := client.getConfig(server, preferTCP) if configErr != nil { client.Logger.Inherit(configErr, fmt.Sprintf( "Failed getting an institute access server configuration with error: %s", - GetErrorTraceback(configErr), + types.GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } +// GetConfigSecureInternet gets a configuration for a Secure Internet Server. +// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. +// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. +func (client *Client) GetConfigSecureInternet( + orgID string, + preferTCP bool, +) (string, string, error) { + errorMessage := fmt.Sprintf( + "failed getting a configuration for Secure Internet organization %s", + orgID, + ) + + // Not supported with Let's Connect! + if client.isLetsConnect() { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} + } + + client.FSM.GoTransition(STATE_LOADING_SERVER) + + // Get the server if it exists + server, serverErr := client.Servers.GetSecureInternetHomeServer() + if serverErr != nil { + client.Logger.Error( + fmt.Sprintf( + "Failed getting a custom server configuration with error: %s", + types.GetErrorTraceback(serverErr), + ), + ) + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + client.FSM.GoTransition(STATE_CHOSEN_SERVER) + + config, configType, configErr := client.getConfig(server, preferTCP) + if configErr != nil { + client.Logger.Inherit( + configErr, + fmt.Sprintf( + "Failed getting a secure internet configuration with error: %s", + types.GetErrorTraceback(configErr), + ), + ) + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + } + return config, configType, nil +} + + // GetConfigCustomServer gets a configuration for a Custom Server. // It ensures that the Custom Server exists by creating or using an existing one with the url. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. func (client *Client) GetConfigCustomServer(url string, preferTCP bool) (string, string, error) { errorMessage := fmt.Sprintf("failed getting a configuration for custom server %s", url) + + url, urlErr := util.EnsureValidURL(url) + if urlErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + } + client.FSM.GoTransition(STATE_LOADING_SERVER) - server, serverErr := client.addCustomServer(url) + // Get the server if it exists + server, serverErr := client.Servers.GetCustomServer(url) if serverErr != nil { client.Logger.Error( fmt.Sprintf( - "Failed adding a custom server with error: %s", - GetErrorTraceback(serverErr), + "Failed getting a custom server configuration with error: %s", + types.GetErrorTraceback(serverErr), ), ) - client.goBackInternal() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + client.FSM.GoTransition(STATE_CHOSEN_SERVER) + config, configType, configErr := client.getConfig(server, preferTCP) if configErr != nil { client.Logger.Inherit( configErr, fmt.Sprintf( "Failed getting a custom server with error: %s", - GetErrorTraceback(configErr), + types.GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } @@ -694,7 +761,7 @@ func (client *Client) ChangeSecureLocation() error { client.Logger.Error( fmt.Sprintf( "Failed changing secure internet location, err: %s", - GetErrorTraceback(askLocationErr), + types.GetErrorTraceback(askLocationErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: askLocationErr} @@ -722,7 +789,7 @@ func (client *Client) GetDiscoOrganizations() (*types.DiscoveryOrganizations, er client.Logger.Warning( fmt.Sprintf( "Failed getting discovery organizations, Err: %s", - GetErrorTraceback(orgsErr), + types.GetErrorTraceback(orgsErr), ), ) return nil, &types.WrappedErrorMessage{ @@ -733,7 +800,7 @@ func (client *Client) GetDiscoOrganizations() (*types.DiscoveryOrganizations, er return orgs, nil } -// GetDiscoDiscovers gets the servers list from the discovery server +// GetDiscoServers gets the servers list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. @@ -748,7 +815,7 @@ func (client *Client) GetDiscoServers() (*types.DiscoveryServers, error) { servers, serversErr := client.Discovery.GetServersList() if serversErr != nil { client.Logger.Warning( - fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)), + fmt.Sprintf("Failed getting discovery servers, Err: %s", types.GetErrorTraceback(serversErr)), ) return nil, &types.WrappedErrorMessage{ Message: errorMessage, @@ -767,7 +834,7 @@ func (client *Client) SetProfileID(profileID string) error { client.Logger.Warning( fmt.Sprintf( "Failed setting a profile ID because no server configured, Err: %s", - GetErrorTraceback(serverErr), + types.GetErrorTraceback(serverErr), ), ) client.goBackInternal() @@ -777,7 +844,7 @@ func (client *Client) SetProfileID(profileID string) error { base, baseErr := server.GetBase() if baseErr != nil { client.Logger.Error( - fmt.Sprintf("Failed setting a profile ID, Err: %s", GetErrorTraceback(serverErr)), + fmt.Sprintf("Failed setting a profile ID, Err: %s", types.GetErrorTraceback(serverErr)), ) client.goBackInternal() return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} @@ -801,7 +868,7 @@ func (client *Client) SetSearchServer() error { Message: "failed to set search server", Err: FSMWrongStateTransitionError{ Got: client.FSM.Current, - Want: STATE_CONNECTED, + Want: STATE_SEARCH_SERVER, }.CustomError(), } } @@ -841,7 +908,7 @@ func (client *Client) SetConnected() error { client.Logger.Warning( fmt.Sprintf( "Failed setting connected, cannot get current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -882,7 +949,7 @@ func (client *Client) SetConnecting() error { client.Logger.Warning( fmt.Sprintf( "Failed setting connecting, cannot get current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -923,7 +990,7 @@ func (client *Client) SetDisconnecting() error { client.Logger.Warning( fmt.Sprintf( "Failed setting disconnected, cannot get current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -965,7 +1032,7 @@ func (client *Client) SetDisconnected(cleanup bool) error { client.Logger.Warning( fmt.Sprintf( "Failed setting disconnect, failed getting current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -991,7 +1058,7 @@ func (client *Client) RenewSession() error { client.Logger.Warning( fmt.Sprintf( "Failed getting current server to renew, error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -1002,7 +1069,7 @@ func (client *Client) RenewSession() error { client.Logger.Warning( fmt.Sprintf( "Failed logging in server for renew, error: %s", - GetErrorTraceback(loginErr), + types.GetErrorTraceback(loginErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} @@ -1027,7 +1094,7 @@ func (client *Client) ShouldRenewButton() bool { client.Logger.Info( fmt.Sprintf( "No server found to renew with err: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return false @@ -1041,27 +1108,12 @@ func (client *Client) InFSMState(checkState FSMStateID) bool { return client.FSM.InState(checkState) } -// GetErrorCause gets the cause for error `err`. -func GetErrorCause(err error) error { - return types.GetErrorCause(err) -} - -// GetErrorCause gets the level for error `err`. -func GetErrorLevel(err error) types.ErrorLevel { - return types.GetErrorLevel(err) -} - -// GetErrorCause gets the traceback for error `err`. -func GetErrorTraceback(err error) string { - return types.GetErrorTraceback(err) -} - // GetTranslated gets the translation for `languages` using the current state language. func (client *Client) GetTranslated(languages map[string]string) string { return util.GetLanguageMatched(languages, client.Language) } -type LetsConnectNotSupportedError struct {} +type LetsConnectNotSupportedError struct{} func (e LetsConnectNotSupportedError) Error() string { return "Any operation that involves discovery is not allowed with the Let's Connect! client" |
