From 565237c14a303a46d62d240b35c6f0082424256a Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Tue, 11 Oct 2022 10:19:56 +0200 Subject: Client: Refactor out adding a Server from getting a config --- client.go | 306 +++++++++++++++++++------------- client_test.go | 31 +++- cmd/cli/main.go | 17 +- exports/exports.go | 43 ++++- fsm.go | 1 + internal/server/custom.go | 12 ++ internal/server/instituteaccess.go | 7 + internal/server/secureinternet.go | 10 +- wrappers/python/eduvpn_common/loader.py | 12 ++ wrappers/python/eduvpn_common/main.py | 20 ++- wrappers/python/main.py | 3 +- wrappers/python/tests.py | 1 + 12 files changed, 325 insertions(+), 138 deletions(-) diff --git a/client.go b/client.go index 9237b04..5745d97 100644 --- a/client.go +++ b/client.go @@ -11,8 +11,8 @@ import ( "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server" - "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" + "github.com/eduvpn/eduvpn-common/types" ) type ( @@ -134,7 +134,7 @@ func (client *Client) Deregister() { client.Logger.Info( fmt.Sprintf( "Failed saving configuration, error: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -150,7 +150,7 @@ func (client *Client) goBackInternal() { client.Logger.Info( fmt.Sprintf( "Failed going back, error: %s", - GetErrorTraceback(goBackErr), + types.GetErrorTraceback(goBackErr), ), ) } @@ -179,7 +179,7 @@ func (client *Client) ensureLogin(chosenServer server.Server) error { // Relogin with oauth // This moves the state to authorized if server.NeedsRelogin(chosenServer) { - url, urlErr := server.GetOAuthURL(chosenServer, client.FSM.Name) + url, urlErr := server.GetOAuthURL(chosenServer, client.Name) client.FSM.GoTransitionWithData(STATE_OAUTH_STARTED, url, true) @@ -277,7 +277,7 @@ func (client *Client) getConfig( config, configType, configErr := client.retryConfigAuth(chosenServer, preferTCP) if configErr != nil { - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } currentServer, currentServerErr := client.Servers.GetCurrentServer() @@ -294,7 +294,7 @@ func (client *Client) getConfig( client.Logger.Info( fmt.Sprintf( "Failed saving configuration after getting a server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -318,7 +318,7 @@ func (client *Client) SetSecureLocation(countryCode string) error { fmt.Sprintf( "Failed getting secure internet server by country code: %s with error: %s", countryCode, - GetErrorTraceback(serverErr), + types.GetErrorTraceback(serverErr), ), ) client.goBackInternal() @@ -330,7 +330,7 @@ func (client *Client) SetSecureLocation(countryCode string) error { client.Logger.Error( fmt.Sprintf( "Failed setting secure internet server with error: %s", - GetErrorTraceback(setLocationErr), + types.GetErrorTraceback(setLocationErr), ), ) client.goBackInternal() @@ -367,42 +367,6 @@ func (client *Client) askSecureLocation() error { return nil } -// addSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file. -// Because there is only one Secure Internet Home Server, it replaces the existing one. -func (client *Client) addSecureInternetHomeServer(orgID string) (server.Server, error) { - errorMessage := fmt.Sprintf( - "failed adding Secure Internet home server with organization ID %s", - orgID, - ) - // Get the secure internet URL from discovery - secureOrg, secureServer, discoErr := client.Discovery.GetSecureHomeArgs(orgID) - if discoErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} - } - - // Add the secure internet server - server, serverErr := client.Servers.AddSecureInternet(secureOrg, secureServer) - - if serverErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} - } - - var locationErr error - - if !client.Servers.HasSecureLocation() { - locationErr = client.askSecureLocation() - } else { - // reinitialize - locationErr = client.SetSecureLocation(client.Servers.GetSecureLocation()) - } - - if locationErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: locationErr} - } - - return server, nil -} - // RemoveSecureInternet removes the current secure internet server. // It returns an error if the server cannot be removed due to the state being DEREGISTERED. // Note that if the server does not exist, it returns nil as an error. @@ -423,7 +387,7 @@ func (client *Client) RemoveSecureInternet() error { client.Logger.Info( fmt.Sprintf( "Failed saving configuration after removing a secure internet server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -449,7 +413,7 @@ func (client *Client) RemoveInstituteAccess(url string) error { client.Logger.Info( fmt.Sprintf( "Failed saving configuration after removing an institute access server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } @@ -475,79 +439,108 @@ func (client *Client) RemoveCustomServer(url string) error { client.Logger.Info( fmt.Sprintf( "Failed saving configuration after removing a custom server: %s", - GetErrorTraceback(saveErr), + types.GetErrorTraceback(saveErr), ), ) } return nil } -// GetConfigSecureInternet gets a configuration for a Secure Internet Server. -// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. -// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. -func (client *Client) GetConfigSecureInternet( - orgID string, - preferTCP bool, -) (string, string, error) { - errorMessage := fmt.Sprintf( - "failed getting a configuration for Secure Internet organization %s", - orgID, - ) +// AddInstituteServer adds an Institute Access server by `url`. +func (client *Client) AddInstituteServer(url string) (server.Server, error) { + errorMessage := fmt.Sprintf("failed adding Institute Access server with url %s", url) // Not supported with Let's Connect! if client.isLetsConnect() { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} } + // Indicate that we're loading the server client.FSM.GoTransition(STATE_LOADING_SERVER) - server, serverErr := client.addSecureInternetHomeServer(orgID) + + // FIXME: Do nothing with discovery here as the client already has it + // So pass a server as the parameter + instituteServer, discoErr := client.Discovery.GetServerByURL(url, "institute_access") + if discoErr != nil { + client.goBackInternal() + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} + } + + // Add the secure internet server + server, serverErr := client.Servers.AddInstituteAccessServer(instituteServer) if serverErr != nil { - client.Logger.Error( - fmt.Sprintf( - "Failed adding a secure internet server with error: %s", - GetErrorTraceback(serverErr), - ), - ) client.goBackInternal() - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + // Indicate that we want to authorize this server client.FSM.GoTransition(STATE_CHOSEN_SERVER) - config, configType, configErr := client.getConfig(server, preferTCP) - if configErr != nil { - client.Logger.Inherit( - configErr, - fmt.Sprintf( - "Failed getting a secure internet configuration with error: %s", - GetErrorTraceback(configErr), - ), - ) - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + // Authorize it + loginErr := client.ensureLogin(server) + if loginErr != nil { + // Removing is best effort + _ = client.RemoveInstituteAccess(url) + return nil, &types.WrappedErrorMessage{Level: types.GetErrorLevel(loginErr), Message: errorMessage, Err: loginErr} } - return config, configType, nil + + client.FSM.GoTransitionWithData(STATE_NO_SERVER, client.Servers, false) + return server, nil } -// addInstituteServer adds an Institute Access server by `url`. -func (client *Client) addInstituteServer(url string) (server.Server, error) { - errorMessage := fmt.Sprintf("failed adding Institute Access server with url %s", url) - instituteServer, discoErr := client.Discovery.GetServerByURL(url, "institute_access") +// AddSecureInternetHomeServer adds a Secure Internet Home Server with `orgID` that was obtained from the Discovery file. +// Because there is only one Secure Internet Home Server, it replaces the existing one. +func (client *Client) AddSecureInternetHomeServer(orgID string) (server.Server, error) { + errorMessage := fmt.Sprintf( + "failed adding Secure Internet home server with organization ID %s", + orgID, + ) + + // Not supported with Let's Connect! + if client.isLetsConnect() { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} + } + + // Indicate that we're loading the server + client.FSM.GoTransition(STATE_LOADING_SERVER) + + // Get the secure internet URL from discovery + secureOrg, secureServer, discoErr := client.Discovery.GetSecureHomeArgs(orgID) if discoErr != nil { + client.goBackInternal() return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} } + // Add the secure internet server - server, serverErr := client.Servers.AddInstituteAccessServer(instituteServer) + server, serverErr := client.Servers.AddSecureInternet(secureOrg, secureServer) if serverErr != nil { + client.goBackInternal() return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + locationErr := client.askSecureLocation() + if locationErr != nil { + // Removing is best effort + _ = client.RemoveSecureInternet() + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: locationErr} + } + + // Server has been chosen for authentication client.FSM.GoTransition(STATE_CHOSEN_SERVER) + // Authorize it + loginErr := client.ensureLogin(server) + if loginErr != nil { + // Removing is best effort + _ = client.RemoveSecureInternet() + return nil, &types.WrappedErrorMessage{Level: types.GetErrorLevel(loginErr), Message: errorMessage, Err: loginErr} + } + client.FSM.GoTransitionWithData(STATE_NO_SERVER, client.Servers, false) return server, nil } -// addCustomServer adds a Custom Server by `url` -func (client *Client) addCustomServer(url string) (server.Server, error) { +// AddCustomServer adds a Custom Server by `url` +func (client *Client) AddCustomServer(url string) (server.Server, error) { errorMessage := fmt.Sprintf("failed adding Custom server with url %s", url) url, urlErr := util.EnsureValidURL(url) @@ -555,6 +548,9 @@ func (client *Client) addCustomServer(url string) (server.Server, error) { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} } + // Indicate that we're loading the server + client.FSM.GoTransition(STATE_LOADING_SERVER) + customServer := &types.DiscoveryServer{ BaseURL: url, DisplayName: map[string]string{"en": url}, @@ -564,11 +560,22 @@ func (client *Client) addCustomServer(url string) (server.Server, error) { // A custom server is just an institute access server under the hood server, serverErr := client.Servers.AddCustomServer(customServer) if serverErr != nil { + client.goBackInternal() return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + // Server has been chosen for authentication client.FSM.GoTransition(STATE_CHOSEN_SERVER) + // Authorize it + loginErr := client.ensureLogin(server) + if loginErr != nil { + // removing is best effort + _ = client.RemoveCustomServer(url) + return nil, &types.WrappedErrorMessage{Level: types.GetErrorLevel(loginErr), Message: errorMessage, Err: loginErr} + } + + client.FSM.GoTransitionWithData(STATE_NO_SERVER, client.Servers, false) return server, nil } @@ -584,60 +591,120 @@ func (client *Client) GetConfigInstituteAccess(url string, preferTCP bool) (stri } client.FSM.GoTransition(STATE_LOADING_SERVER) - server, serverErr := client.addInstituteServer(url) + + // Get the server if it exists + server, serverErr := client.Servers.GetInstituteAccess(url) if serverErr != nil { client.Logger.Error( fmt.Sprintf( - "Failed adding an institute access server with error: %s", - GetErrorTraceback(serverErr), + "Failed getting an institute access server configuration with error: %s", + types.GetErrorTraceback(serverErr), ), ) - client.goBackInternal() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + // The server has now been chosen + client.FSM.GoTransition(STATE_CHOSEN_SERVER) + config, configType, configErr := client.getConfig(server, preferTCP) if configErr != nil { client.Logger.Inherit(configErr, fmt.Sprintf( "Failed getting an institute access server configuration with error: %s", - GetErrorTraceback(configErr), + types.GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } +// GetConfigSecureInternet gets a configuration for a Secure Internet Server. +// It ensures that the Secure Internet Server exists by creating or using an existing one with the orgID. +// `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. +func (client *Client) GetConfigSecureInternet( + orgID string, + preferTCP bool, +) (string, string, error) { + errorMessage := fmt.Sprintf( + "failed getting a configuration for Secure Internet organization %s", + orgID, + ) + + // Not supported with Let's Connect! + if client.isLetsConnect() { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: LetsConnectNotSupportedError{}} + } + + client.FSM.GoTransition(STATE_LOADING_SERVER) + + // Get the server if it exists + server, serverErr := client.Servers.GetSecureInternetHomeServer() + if serverErr != nil { + client.Logger.Error( + fmt.Sprintf( + "Failed getting a custom server configuration with error: %s", + types.GetErrorTraceback(serverErr), + ), + ) + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + client.FSM.GoTransition(STATE_CHOSEN_SERVER) + + config, configType, configErr := client.getConfig(server, preferTCP) + if configErr != nil { + client.Logger.Inherit( + configErr, + fmt.Sprintf( + "Failed getting a secure internet configuration with error: %s", + types.GetErrorTraceback(configErr), + ), + ) + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + } + return config, configType, nil +} + + // GetConfigCustomServer gets a configuration for a Custom Server. // It ensures that the Custom Server exists by creating or using an existing one with the url. // `preferTCP` indicates that the client wants to use TCP (through OpenVPN) to establish the VPN tunnel. func (client *Client) GetConfigCustomServer(url string, preferTCP bool) (string, string, error) { errorMessage := fmt.Sprintf("failed getting a configuration for custom server %s", url) + + url, urlErr := util.EnsureValidURL(url) + if urlErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + } + client.FSM.GoTransition(STATE_LOADING_SERVER) - server, serverErr := client.addCustomServer(url) + // Get the server if it exists + server, serverErr := client.Servers.GetCustomServer(url) if serverErr != nil { client.Logger.Error( fmt.Sprintf( - "Failed adding a custom server with error: %s", - GetErrorTraceback(serverErr), + "Failed getting a custom server configuration with error: %s", + types.GetErrorTraceback(serverErr), ), ) - client.goBackInternal() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } + client.FSM.GoTransition(STATE_CHOSEN_SERVER) + config, configType, configErr := client.getConfig(server, preferTCP) if configErr != nil { client.Logger.Inherit( configErr, fmt.Sprintf( "Failed getting a custom server with error: %s", - GetErrorTraceback(configErr), + types.GetErrorTraceback(configErr), ), ) - return "", "", &types.WrappedErrorMessage{Level: GetErrorLevel(configErr), Message: errorMessage, Err: configErr} + return "", "", &types.WrappedErrorMessage{Level: types.GetErrorLevel(configErr), Message: errorMessage, Err: configErr} } return config, configType, nil } @@ -694,7 +761,7 @@ func (client *Client) ChangeSecureLocation() error { client.Logger.Error( fmt.Sprintf( "Failed changing secure internet location, err: %s", - GetErrorTraceback(askLocationErr), + types.GetErrorTraceback(askLocationErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: askLocationErr} @@ -722,7 +789,7 @@ func (client *Client) GetDiscoOrganizations() (*types.DiscoveryOrganizations, er client.Logger.Warning( fmt.Sprintf( "Failed getting discovery organizations, Err: %s", - GetErrorTraceback(orgsErr), + types.GetErrorTraceback(orgsErr), ), ) return nil, &types.WrappedErrorMessage{ @@ -733,7 +800,7 @@ func (client *Client) GetDiscoOrganizations() (*types.DiscoveryOrganizations, er return orgs, nil } -// GetDiscoDiscovers gets the servers list from the discovery server +// GetDiscoServers gets the servers list from the discovery server // If the list cannot be retrieved an error is returned. // If this is the case then a previous version of the list is returned if there is any. // This takes into account the frequency of updates, see: https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md#server-list. @@ -748,7 +815,7 @@ func (client *Client) GetDiscoServers() (*types.DiscoveryServers, error) { servers, serversErr := client.Discovery.GetServersList() if serversErr != nil { client.Logger.Warning( - fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)), + fmt.Sprintf("Failed getting discovery servers, Err: %s", types.GetErrorTraceback(serversErr)), ) return nil, &types.WrappedErrorMessage{ Message: errorMessage, @@ -767,7 +834,7 @@ func (client *Client) SetProfileID(profileID string) error { client.Logger.Warning( fmt.Sprintf( "Failed setting a profile ID because no server configured, Err: %s", - GetErrorTraceback(serverErr), + types.GetErrorTraceback(serverErr), ), ) client.goBackInternal() @@ -777,7 +844,7 @@ func (client *Client) SetProfileID(profileID string) error { base, baseErr := server.GetBase() if baseErr != nil { client.Logger.Error( - fmt.Sprintf("Failed setting a profile ID, Err: %s", GetErrorTraceback(serverErr)), + fmt.Sprintf("Failed setting a profile ID, Err: %s", types.GetErrorTraceback(serverErr)), ) client.goBackInternal() return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} @@ -801,7 +868,7 @@ func (client *Client) SetSearchServer() error { Message: "failed to set search server", Err: FSMWrongStateTransitionError{ Got: client.FSM.Current, - Want: STATE_CONNECTED, + Want: STATE_SEARCH_SERVER, }.CustomError(), } } @@ -841,7 +908,7 @@ func (client *Client) SetConnected() error { client.Logger.Warning( fmt.Sprintf( "Failed setting connected, cannot get current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -882,7 +949,7 @@ func (client *Client) SetConnecting() error { client.Logger.Warning( fmt.Sprintf( "Failed setting connecting, cannot get current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -923,7 +990,7 @@ func (client *Client) SetDisconnecting() error { client.Logger.Warning( fmt.Sprintf( "Failed setting disconnected, cannot get current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -965,7 +1032,7 @@ func (client *Client) SetDisconnected(cleanup bool) error { client.Logger.Warning( fmt.Sprintf( "Failed setting disconnect, failed getting current server with error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -991,7 +1058,7 @@ func (client *Client) RenewSession() error { client.Logger.Warning( fmt.Sprintf( "Failed getting current server to renew, error: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} @@ -1002,7 +1069,7 @@ func (client *Client) RenewSession() error { client.Logger.Warning( fmt.Sprintf( "Failed logging in server for renew, error: %s", - GetErrorTraceback(loginErr), + types.GetErrorTraceback(loginErr), ), ) return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} @@ -1027,7 +1094,7 @@ func (client *Client) ShouldRenewButton() bool { client.Logger.Info( fmt.Sprintf( "No server found to renew with err: %s", - GetErrorTraceback(currentServerErr), + types.GetErrorTraceback(currentServerErr), ), ) return false @@ -1041,27 +1108,12 @@ func (client *Client) InFSMState(checkState FSMStateID) bool { return client.FSM.InState(checkState) } -// GetErrorCause gets the cause for error `err`. -func GetErrorCause(err error) error { - return types.GetErrorCause(err) -} - -// GetErrorCause gets the level for error `err`. -func GetErrorLevel(err error) types.ErrorLevel { - return types.GetErrorLevel(err) -} - -// GetErrorCause gets the traceback for error `err`. -func GetErrorTraceback(err error) string { - return types.GetErrorTraceback(err) -} - // GetTranslated gets the translation for `languages` using the current state language. func (client *Client) GetTranslated(languages map[string]string) string { return util.GetLanguageMatched(languages, client.Language) } -type LetsConnectNotSupportedError struct {} +type LetsConnectNotSupportedError struct{} func (e LetsConnectNotSupportedError) Error() string { return "Any operation that involves discovery is not allowed with the Let's Connect! client" diff --git a/client_test.go b/client_test.go index e618d07..87d00f7 100644 --- a/client_test.go +++ b/client_test.go @@ -85,6 +85,10 @@ func Test_server(t *testing.T) { t.Fatalf("Register error: %v", registerErr) } + _, addErr := state.AddCustomServer(serverURI) + if addErr != nil { + t.Fatalf("Add error: %v", addErr) + } _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("Connect error: %v", configErr) @@ -137,13 +141,14 @@ func test_connect_oauth_parameter( if registerErr != nil { t.Fatalf("Register error: %v", registerErr) } - _, _, configErr := state.GetConfigCustomServer(serverURI, false) + + _, addErr := state.AddCustomServer(serverURI) var wrappedErr *types.WrappedErrorMessage // We ensure the error is of a wrappedErrorMessage - if !errors.As(configErr, &wrappedErr) { - t.Fatalf("error %T = %v, wantErr %T", configErr, configErr, wrappedErr) + if !errors.As(addErr, &wrappedErr) { + t.Fatalf("error %T = %v, wantErr %T", addErr, addErr, wrappedErr) } gotExpectedErr := wrappedErr.Cause() @@ -206,6 +211,11 @@ func Test_token_expired(t *testing.T) { t.Fatalf("Register error: %v", registerErr) } + _, addErr := state.AddCustomServer(serverURI) + if addErr != nil { + t.Fatalf("Add error: %v", addErr) + } + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { @@ -261,6 +271,11 @@ func Test_token_invalid(t *testing.T) { t.Fatalf("Register error: %v", registerErr) } + _, addErr := state.AddCustomServer(serverURI) + if addErr != nil { + t.Fatalf("Add error: %v", addErr) + } + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { @@ -313,6 +328,11 @@ func Test_invalid_profile_corrected(t *testing.T) { t.Fatalf("Register error: %v", registerErr) } + _, addErr := state.AddCustomServer(serverURI) + if addErr != nil { + t.Fatalf("Add error: %v", addErr) + } + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { @@ -365,6 +385,11 @@ func Test_prefer_tcp(t *testing.T) { t.Fatalf("Register error: %v", registerErr) } + _, addErr := state.AddCustomServer(serverURI) + if addErr != nil { + t.Fatalf("Add error: %v", addErr) + } + // get a config with preferTCP set to true config, configType, configErr := state.GetConfigCustomServer(serverURI, true) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 14ab6ea..dbded2d 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -8,6 +8,7 @@ import ( eduvpn "github.com/eduvpn/eduvpn-common" "github.com/eduvpn/eduvpn-common/internal/server" + "github.com/eduvpn/eduvpn-common/types" ) type ServerTypes int8 @@ -97,10 +98,22 @@ func getConfig(state *eduvpn.Client, url string, serverType ServerTypes) (string } // Prefer TCP is set to False if serverType == ServerTypeInstituteAccess { + _, addErr := state.AddInstituteServer(url) + if addErr != nil { + return "", "", addErr + } return state.GetConfigInstituteAccess(url, false) } else if serverType == ServerTypeCustom { + _, addErr := state.AddCustomServer(url) + if addErr != nil { + return "", "", addErr + } return state.GetConfigCustomServer(url, false) } + _, addErr := state.AddSecureInternetHomeServer(url) + if addErr != nil { + return "", "", addErr + } return state.GetConfigSecureInternet(url, false) } @@ -128,8 +141,8 @@ func printConfig(url string, serverType ServerTypes) { if configErr != nil { // Show the usage of tracebacks and causes - fmt.Println("Error getting config:", eduvpn.GetErrorTraceback(configErr)) - fmt.Println("Error getting config, cause:", eduvpn.GetErrorCause(configErr)) + fmt.Println("Error getting config:", types.GetErrorTraceback(configErr)) + fmt.Println("Error getting config, cause:", types.GetErrorCause(configErr)) return } diff --git a/exports/exports.go b/exports/exports.go index 5b0a619..e319e87 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -15,6 +15,7 @@ import "C" import ( "fmt" + "github.com/eduvpn/eduvpn-common/types" "unsafe" eduvpn "github.com/eduvpn/eduvpn-common" @@ -139,9 +140,9 @@ func getError(err error) *C.error { errorStruct := (*C.error)( C.malloc(C.size_t(unsafe.Sizeof(C.error{}))), ) - errorStruct.level = C.errorLevel(eduvpn.GetErrorLevel(err)) - errorStruct.traceback = C.CString(eduvpn.GetErrorTraceback(err)) - errorStruct.cause = C.CString(eduvpn.GetErrorCause(err).Error()) + errorStruct.level = C.errorLevel(types.GetErrorLevel(err)) + errorStruct.traceback = C.CString(types.GetErrorTraceback(err)) + errorStruct.cause = C.CString(types.GetErrorCause(err).Error()) return errorStruct } @@ -174,6 +175,42 @@ func RemoveSecureInternet(name *C.char) *C.error { return getError(removeErr) } +//export AddInstituteAccess +func AddInstituteAccess(name *C.char, url *C.char) *C.error { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return getError(stateErr) + } + // FIXME: Return server result + _, addErr := state.AddInstituteServer(C.GoString(url)) + return getError(addErr) +} + +//export AddSecureInternetHomeServer +func AddSecureInternetHomeServer(name *C.char, orgID *C.char) *C.error { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return getError(stateErr) + } + // FIXME: Return server result + _, addErr := state.AddSecureInternetHomeServer(C.GoString(orgID)) + return getError(addErr) +} + +//export AddCustomServer +func AddCustomServer(name *C.char, url *C.char) *C.error { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return getError(stateErr) + } + // FIXME: Return server result + _, addErr := state.AddCustomServer(C.GoString(url)) + return getError(addErr) +} + //export RemoveInstituteAccess func RemoveInstituteAccess(name *C.char, url *C.char) *C.error { nameStr := C.GoString(name) diff --git a/fsm.go b/fsm.go index 8f60605..ae5bfdc 100644 --- a/fsm.go +++ b/fsm.go @@ -157,6 +157,7 @@ func newFSM( Transitions: []FSMTransition{ {To: STATE_OAUTH_STARTED, Description: "Re-authorize with OAuth"}, {To: STATE_REQUEST_CONFIG, Description: "Client requests a config"}, + {To: STATE_NO_SERVER, Description: "Client wants to go back to the main screen"}, }, }, STATE_REQUEST_CONFIG: FSMState{ diff --git a/internal/server/custom.go b/internal/server/custom.go index a93242d..52a0094 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -1,5 +1,17 @@ package server +import ( + "fmt" + "github.com/eduvpn/eduvpn-common/types" +) + +func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { + if server, ok := servers.CustomServers.Map[url]; ok { + return server, nil + } + return nil, &types.WrappedErrorMessage{Message: "failed to get institute access server", Err: fmt.Errorf("No custom server with URL: %s", url)} +} + func (servers *Servers) RemoveCustomServer(url string) { servers.CustomServers.Remove(url) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index c5b58ef..f2669b8 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -21,6 +21,13 @@ type InstituteAccessServers struct { CurrentURL string `json:"current_url"` } +func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { + if server, ok := servers.InstituteServers.Map[url]; ok { + return server, nil + } + return nil, &types.WrappedErrorMessage{Message: "failed to get institute access server", Err: fmt.Errorf("No institute access server with URL: %s", url)} +} + func (servers *Servers) RemoveInstituteAccess(url string) { servers.InstituteServers.Remove(url) } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 3981022..776bb72 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -1,11 +1,12 @@ package server import ( + "errors" "fmt" "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" + "github.com/eduvpn/eduvpn-common/types" ) // A secure internet server which has its own OAuth tokens @@ -23,6 +24,13 @@ type SecureInternetHomeServer struct { CurrentLocation string `json:"current_location"` } +func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { + if !servers.HasSecureLocation() { + return nil, errors.New("No secure internet home server") + } + return &servers.SecureInternetHomeServer, nil +} + func (servers *Servers) RemoveSecureInternet() { // Empty out the struct servers.SecureInternetHomeServer = SecureInternetHomeServer{} diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 0192815..23851f3 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -91,6 +91,18 @@ def initialize_functions(lib): c_char_p, c_char_p, ], c_void_p + lib.AddInstituteAccess.argtypes, lib.AddInstituteAccess.restype = [ + c_char_p, + c_char_p, + ], c_void_p + lib.AddSecureInternetHomeServer.argtypes, lib.AddSecureInternetHomeServer.restype = [ + c_char_p, + c_char_p, + ], c_void_p + lib.AddCustomServer.argtypes, lib.AddCustomServer.restype = [ + c_char_p, + c_char_p, + ], c_void_p lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ c_char_p, c_char_p, diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 03e1045..382f356 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -121,6 +121,25 @@ class EduVPN(object): if remove_err: raise remove_err + def add_institute_access(self, url: str): + add_err = self.go_function(self.lib.AddInstituteAccess, url) + + if add_err: + raise add_err + + def add_secure_internet_home(self, org_id: str): + self.location_event = threading.Event() + add_err = self.go_function(self.lib.AddSecureInternetHomeServer, org_id) + + if add_err: + raise add_err + + def add_custom_server(self, url: str): + add_err = self.go_function(self.lib.AddCustomServer, url) + + if add_err: + raise add_err + def remove_institute_access(self, url: str): remove_err = self.go_function(self.lib.RemoveInstituteAccess, url) @@ -162,7 +181,6 @@ class EduVPN(object): def get_config_secure_internet( self, url: str, prefer_tcp: bool = False ) -> Tuple[str, str]: - self.location_event = threading.Event() return self.get_config(url, self.lib.GetConfigSecureInternet, prefer_tcp) def go_back(self) -> None: diff --git a/wrappers/python/main.py b/wrappers/python/main.py index 5604452..4881828 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -67,7 +67,7 @@ def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None: # The main entry point if __name__ == "__main__": - _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "configs") + _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "configs", "en") setup_callbacks(_eduvpn) # Register with the eduVPN-common library @@ -82,6 +82,7 @@ if __name__ == "__main__": # Get a Wireguard/OpenVPN config try: + _eduvpn.add_secure_internet("https://idp.geant.org") config, config_type = _eduvpn.get_config_secure_internet("https://idp.geant.org") print(f"Got a config with type: {config_type} and contents:\n{config}") except Exception as e: diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py index d3caa38..555a0bb 100644 --- a/wrappers/python/tests.py +++ b/wrappers/python/tests.py @@ -31,6 +31,7 @@ class ConfigTests(unittest.TestCase): self.fail("No SERVER_URI environment variable given") # This can throw an exception + _eduvpn.add_custom_server(server_uri) _eduvpn.get_config_custom_server(server_uri) # Deregister -- cgit v1.2.3