From 1ef27cc47ad56a2c66aaa40e398a0063be2573d4 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 9 May 2022 14:18:23 +0200 Subject: FSM/State: Profile correctness and connect name change Also add a force tcp flag --- cmd/cli/main.go | 4 +-- exports/exports.go | 13 +++++++-- internal/api.go | 8 ++++++ internal/fsm.go | 10 +++---- internal/server.go | 22 +++++++++++--- state.go | 26 +++++++++++------ state_test.go | 63 ++++++++++++++++++++++++++++++++++------- wrappers/python/main.py | 2 +- wrappers/python/src/__init__.py | 6 ++-- wrappers/python/src/main.py | 13 +++++---- 10 files changed, 124 insertions(+), 43 deletions(-) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 5bdbd8d..10a4b29 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -94,9 +94,9 @@ func getConfig(url string, isInstitute bool) (string, error) { defer state.Deregister() if isInstitute { - return state.ConnectInstituteAccess(url) + return state.GetConfigInstituteAccess(url, false) } - return state.ConnectSecureInternet(url) + return state.GetConfigSecureInternet(url, false) } type ServerDiscoEntry struct { diff --git a/exports/exports.go b/exports/exports.go index 434bfbd..919f5d7 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -96,14 +96,21 @@ func CancelOAuth(name *C.char) *C.char { return C.CString(cancelErrString) } -//export Connect -func Connect(name *C.char, url *C.char) (*C.char, *C.char) { +//export GetConnectConfig +func GetConnectConfig(name *C.char, url *C.char, isSecureInternet C.int, forceTCP C.int) (*C.char, *C.char) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { return nil, C.CString(ErrorToString(stateErr)) } - config, configErr := state.ConnectInstituteAccess(C.GoString(url)) + var config string + var configErr error + forceTCPBool := forceTCP == 1 + if isSecureInternet == 1 { + config, configErr = state.GetConfigSecureInternet(C.GoString(url), forceTCPBool) + } else { + config, configErr = state.GetConfigInstituteAccess(C.GoString(url), forceTCPBool) + } return C.CString(config), C.CString(ErrorToString(configErr)) } diff --git a/internal/api.go b/internal/api.go index a987f00..b615976 100644 --- a/internal/api.go +++ b/internal/api.go @@ -24,8 +24,12 @@ func apiAuthorized(server Server, method string, endpoint string, opts *HTTPOpti url := base.Endpoints.API.V3.API + endpoint // Ensure we have valid tokens + stateBefore := base.FSM.Current oauthErr := EnsureTokens(server) + // we reset the state so that we go from the authorized state to the state we want + base.FSM.Current = stateBefore + if oauthErr != nil { return nil, nil, oauthErr } @@ -83,7 +87,11 @@ func APIInfo(server Server) error { if baseErr != nil { return &APIInfoError{Err: baseErr} } + + // Store the profiles and make sure that the current profile is not overwritten + previousProfile := base.Profiles.Current base.Profiles = structure + base.Profiles.Current = previousProfile base.ProfilesRaw = string(body) return nil } diff --git a/internal/fsm.go b/internal/fsm.go index 0b9ad1e..b52b463 100644 --- a/internal/fsm.go +++ b/internal/fsm.go @@ -106,12 +106,12 @@ func (fsm *FSM) Init(name string, callback func(string, string, string), logger DEREGISTERED: {{NO_SERVER, "Client registers"}}, NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}}, CHOSEN_SERVER: {{AUTHORIZED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}}, - OAUTH_STARTED: {{AUTHORIZED, "User authorizes with browser"}, {CHOSEN_SERVER, "Cancel OAuth"}}, + OAUTH_STARTED: {{AUTHORIZED, "User authorizes with browser"}, {NO_SERVER, "Cancel or Error"}}, AUTHORIZED: {{OAUTH_STARTED, "Re-authorize with OAuth"}, {REQUEST_CONFIG, "Client requests a config"}}, - REQUEST_CONFIG: {{ASK_PROFILE, "Multiple profiles found"}, {HAS_CONFIG, "Success, only one profile"}}, - ASK_PROFILE: {{HAS_CONFIG, "User chooses profile and success"}}, - HAS_CONFIG: {{CONNECTED, "OS reports connected"}}, - CONNECTED: {{AUTHORIZED, "OS reports disconnected"}}, + REQUEST_CONFIG: {{ASK_PROFILE, "Multiple profiles found and no profile chosen"}, {HAS_CONFIG, "Only one profile or profile already chosen"}, {NO_SERVER, "Cancel or Error"}, {OAUTH_STARTED, "Re-authorize"}}, + ASK_PROFILE: {{HAS_CONFIG, "User chooses profile"}, {NO_SERVER, "Done but no profile selected"}}, + HAS_CONFIG: {{CONNECTED, "OS reports connected"}, {REQUEST_CONFIG, "User chooses a new profile"}, {NO_SERVER, "User wants to choose a new server"}}, + CONNECTED: {{HAS_CONFIG, "OS reports disconnected"}}, } fsm.Current = DEREGISTERED fsm.Name = name diff --git a/internal/server.go b/internal/server.go index 9d27907..4452c6d 100644 --- a/internal/server.go +++ b/internal/server.go @@ -340,15 +340,29 @@ func GetConfig(server Server) (string, error) { if !base.FSM.InState(REQUEST_CONFIG) { return "", &FSMWrongStateError{Got: base.FSM.Current, Want: REQUEST_CONFIG} } - infoErr := APIInfo(server) + // Get new profiles using the info call + // This does not override the current profile + infoErr := APIInfo(server) if infoErr != nil { return "", &ServerGetConfigError{Err: infoErr} } - // Set the current profile if there is only one profile - if len(base.Profiles.Info.ProfileList) == 1 { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID + // If there was a profile chosen and it doesn't exist anymore, reset it + if base.Profiles.Current != "" { + _, existsProfileErr := getCurrentProfile(server) + if existsProfileErr != nil { + base.Logger.Log(LOG_INFO, fmt.Sprintf("Profile %s no longer exists, resetting the profile", base.Profiles.Current)) + base.Profiles.Current = "" + } + } + + // Set the current profile if there is only one profile or profile is already selected + if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { + // Set the first profile if none is selected + if base.Profiles.Current == "" { + base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID + } return getConfigWithProfile(server) } diff --git a/state.go b/state.go index 14b8eb6..7f383d9 100644 --- a/state.go +++ b/state.go @@ -100,11 +100,17 @@ func (state *VPNState) chooseServer(url string, isSecureInternet bool) (internal return server, nil } -func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) { +func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, forceTCP bool) (string, error) { + // FIXME: Do something with force tcp if state.FSM.InState(internal.DEREGISTERED) { return "", &StateFSMNotRegisteredError{} } + // Go to no server if possible, else return an error + if !state.FSM.InState(internal.NO_SERVER) && !state.FSM.GoTransition(internal.NO_SERVER) { + return "", &internal.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: internal.NO_SERVER} + } + // Make sure the server is chosen server, serverErr := state.chooseServer(url, isSecureInternet) @@ -118,8 +124,8 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st if loginErr != nil { // We are possibly in oauth started - // So go to chosen server - state.FSM.GoTransition(internal.CHOSEN_SERVER) + // So go to no server + state.FSM.GoTransition(internal.NO_SERVER) return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: loginErr} } } else { // OAuth was valid, ensure we are in the authorized state @@ -131,6 +137,8 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st config, configErr := internal.GetConfig(server) if configErr != nil { + // Go back to no server if possible + state.FSM.GoTransition(internal.NO_SERVER) return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr} } else { state.FSM.GoTransition(internal.HAS_CONFIG) @@ -139,12 +147,12 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st return config, nil } -func (state *VPNState) ConnectInstituteAccess(url string) (string, error) { - return state.connectWithOptions(url, false) +func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (string, error) { + return state.getConfigWithOptions(url, false, forceTCP) } -func (state *VPNState) ConnectSecureInternet(url string) (string, error) { - return state.connectWithOptions(url, true) +func (state *VPNState) GetConfigSecureInternet(url string, forceTCP bool) (string, error) { + return state.getConfigWithOptions(url, true, forceTCP) } func (state *VPNState) GetDiscoOrganizations() (string, error) { @@ -204,7 +212,7 @@ type StateSetProfileError struct { } func (e *StateSetProfileError) Error() string { - return fmt.Sprintf("failed to set profile ID %s with error %v", e.ProfileID, e.Err) + return fmt.Sprintf("failed to set profile ID: %s with error: %v", e.ProfileID, e.Err) } type StateRegisterError struct { @@ -212,7 +220,7 @@ type StateRegisterError struct { } func (e *StateRegisterError) Error() string { - return fmt.Sprintf("failed to register with error %v", e.Err) + return fmt.Sprintf("failed to register with error: %v", e.Err) } type StateFSMNotRegisteredError struct{} diff --git a/state_test.go b/state_test.go index 84e648d..21b3abd 100644 --- a/state_test.go +++ b/state_test.go @@ -68,7 +68,7 @@ func Test_server(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, configErr := state.ConnectInstituteAccess(serverURI) + _, configErr := state.GetConfigInstituteAccess(serverURI, false) if configErr != nil { t.Fatalf("Connect error: %v", configErr) @@ -91,7 +91,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter } }, false) - _, configErr := state.ConnectInstituteAccess(serverURI) + _, configErr := state.GetConfigInstituteAccess(serverURI, false) var stateErr *StateConnectError var loginErr *internal.OAuthLoginError @@ -171,7 +171,7 @@ func Test_token_expired(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, configErr := state.ConnectInstituteAccess(serverURI) + _, configErr := state.GetConfigInstituteAccess(serverURI, false) if configErr != nil { t.Fatalf("Connect error before expired: %v", configErr) @@ -219,16 +219,16 @@ func Test_token_invalid(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, configErr := state.ConnectInstituteAccess(serverURI) + _, configErr := state.GetConfigInstituteAccess(serverURI, false) if configErr != nil { t.Fatalf("Connect error before invalid: %v", configErr) } - // Fake connect and then back to authorized so that we can re-authorize - // Going to authorized fakes a disconnect - state.FSM.GoTransition(internal.CONNECTED) - state.FSM.GoTransition(internal.AUTHORIZED) + // Go to request_config so we can re-authorize + // This is needed as the only actual authenticated requests we do in request_config (for profiles) and /connect + // /disconnect is best effort so this does not need re-auth + state.FSM.GoTransition(internal.REQUEST_CONFIG) dummy_value := "37" @@ -243,10 +243,10 @@ func Test_token_invalid(t *testing.T) { oauth.Token.Access = dummy_value oauth.Token.Refresh = dummy_value - infoErr := internal.APIInfo(server) + _, configErr = state.GetConfigInstituteAccess(serverURI, false) - if infoErr != nil { - t.Fatalf("Info error after invalid: %v", infoErr) + if configErr != nil { + t.Fatalf("Connect error after invalid: %v", configErr) } if oauth.Token.Access == dummy_value { @@ -257,3 +257,44 @@ func Test_token_invalid(t *testing.T) { t.Errorf("Refresh token is equal to dummy value: %s", dummy_value) } } + +// Test if an invalid profile will be corrected +func Test_invalid_profile_corrected(t *testing.T) { + serverURI := getServerURI(t) + state := &VPNState{} + + ensureLocalWellKnown() + + state.Register("org.eduvpn.app.linux", "configscancelprofile", func(old string, new string, data string) { + stateCallback(t, old, new, data, state) + }, false) + + _, configErr := state.GetConfigInstituteAccess(serverURI, false) + + if configErr != nil { + t.Fatalf("First connect error: %v", configErr) + } + + server, serverErr := state.Servers.GetCurrentServer() + if serverErr != nil { + t.Fatalf("No server found") + } + + base, baseErr := server.GetBase() + if baseErr != nil { + t.Fatalf("No base found") + } + + previousProfile := base.Profiles.Current + base.Profiles.Current = "IDONOTEXIST" + + _, configErr = state.GetConfigInstituteAccess(serverURI, false) + + if configErr != nil { + t.Fatalf("Second connect error: %v", configErr) + } + + if base.Profiles.Current != previousProfile { + t.Fatalf("Profiles do no match: current %s and previous %s", base.Profiles.Current, previousProfile) + } +} diff --git a/wrappers/python/main.py b/wrappers/python/main.py index be9ab6c..1c1afd7 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -24,7 +24,7 @@ if not success: print(_eduvpn.get_disco()) -config, error = _eduvpn.connect("https://eduvpn.jwijenbergh.com") +config, error = _eduvpn.get_config_institute_access("https://eduvpn.jwijenbergh.com") if error: print("Got connect error", error) diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index e417371..c028f09 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -33,15 +33,15 @@ class DataError(Structure): VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p) # Exposed functions -lib.Connect.argtypes, lib.Connect.restype = [c_char_p, c_char_p], DataError +# We have to use c_void_p instead of c_char_p to free it properly +# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error +lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [c_char_p, c_char_p, c_int, c_int], DataError lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], c_void_p lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [c_char_p], DataError lib.GetServersList.argtypes, lib.GetServersList.restype = [c_char_p], DataError lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p -# We have to use c_void_p instead of c_char_p to free it properly -# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p], c_void_p lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 9c2fb41..2b346e3 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -36,11 +36,10 @@ def GetDiscoServers(name): organizations, organizationsErr = GetDataError(lib.GetOrganizationsList(name_bytes)) return servers, serversErr, organizations, organizationsErr - -def Connect(name, url): +def GetConnectConfig(name, url, is_secure_internet, force_tcp): name_bytes = name.encode("utf-8") url_bytes = url.encode("utf-8") - data_error = lib.Connect(name_bytes, url_bytes) + data_error = lib.GetConnectConfig(name_bytes, url_bytes, is_secure_internet, force_tcp) return GetDataError(data_error) def SetConnected(name): @@ -95,8 +94,12 @@ class EduVPN(object): def get_disco(self): return GetDiscoServers(self.name) - def connect(self, url): - return Connect(self.name, url) + def get_config_institute_access(self, url, force_tcp=False): + return GetConnectConfig(self.name, url, False, force_tcp) + + def get_config_secure_internet(self, url, force_tcp=False): + return GetConnectConfig(self.name, url, True, force_tcp) + def set_disconnected(self): return SetDisconnected(self.name) -- cgit v1.2.3