diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-11-22 16:14:06 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-11-23 16:16:09 +0100 |
| commit | 4a4b3f0a1c008e35a4492b7fd05176d1822c7232 (patch) | |
| tree | 287ac69b6f89524282d4e2cbc85c6d8030285c88 | |
| parent | ea07a6d7b2df9b09d8e4c796b2416a60ba90144a (diff) | |
FSM: Check unhandled transitions
| -rw-r--r-- | client/client.go | 10 | ||||
| -rw-r--r-- | client/client_test.go | 18 | ||||
| -rw-r--r-- | client/fsm.go | 2 | ||||
| -rw-r--r-- | client/server.go | 13 | ||||
| -rw-r--r-- | cmd/cli/main.go | 3 | ||||
| -rw-r--r-- | exports/exports.go | 17 | ||||
| -rw-r--r-- | internal/fsm/fsm.go | 21 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/event.py | 15 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 14 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 2 |
10 files changed, 77 insertions, 38 deletions
diff --git a/client/client.go b/client/client.go index 103136b..63e4f91 100644 --- a/client/client.go +++ b/client/client.go @@ -73,7 +73,7 @@ func (client *Client) Register( name string, directory string, language string, - stateCallback func(FSMStateID, FSMStateID, interface{}), + stateCallback func(FSMStateID, FSMStateID, interface{}) bool, debug bool, ) error { errorMessage := "failed to register with the GO library" @@ -155,11 +155,15 @@ func (client *Client) Deregister() { // askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state. func (client *Client) askProfile(chosenServer server.Server) error { + errorMessage := "failed asking for profiles" profiles, profilesErr := server.GetValidProfiles(chosenServer, client.SupportsWireguard) if profilesErr != nil { - return types.NewWrappedError("failed asking for profiles", profilesErr) + return types.NewWrappedError(errorMessage, profilesErr) + } + goTransitionErr := client.FSM.GoTransitionRequired(STATE_ASK_PROFILE, profiles) + if goTransitionErr != nil { + return types.NewWrappedError(errorMessage, goTransitionErr) } - client.FSM.GoTransitionWithData(STATE_ASK_PROFILE, profiles) return nil } diff --git a/client/client_test.go b/client/client_test.go index f386c3c..adf5c75 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -81,8 +81,9 @@ func Test_server(t *testing.T) { "org.letsconnect-vpn.app.linux", "configstest", "en", - func(old FSMStateID, new FSMStateID, data interface{}) { + func(old FSMStateID, new FSMStateID, data interface{}) bool { stateCallback(t, old, new, data, state) + return true }, false, ) @@ -113,7 +114,7 @@ func test_connect_oauth_parameter( "org.letsconnect-vpn.app.linux", configDirectory, "en", - func(oldState FSMStateID, newState FSMStateID, data interface{}) { + func(oldState FSMStateID, newState FSMStateID, data interface{}) bool { if newState == STATE_OAUTH_STARTED { server, serverErr := state.Servers.GetCustomServer(serverURI) if serverErr != nil { @@ -142,6 +143,7 @@ func test_connect_oauth_parameter( } }() } + return true }, false, ) @@ -219,8 +221,9 @@ func Test_token_expired(t *testing.T) { "org.letsconnect-vpn.app.linux", "configsexpired", "en", - func(old FSMStateID, new FSMStateID, data interface{}) { + func(old FSMStateID, new FSMStateID, data interface{}) bool { stateCallback(t, old, new, data, state) + return true }, false, ) @@ -279,8 +282,9 @@ func Test_token_invalid(t *testing.T) { "org.letsconnect-vpn.app.linux", "configsinvalid", "en", - func(old FSMStateID, new FSMStateID, data interface{}) { + func(old FSMStateID, new FSMStateID, data interface{}) bool { stateCallback(t, old, new, data, state) + return true }, false, ) @@ -336,8 +340,9 @@ func Test_invalid_profile_corrected(t *testing.T) { "org.letsconnect-vpn.app.linux", "configscancelprofile", "en", - func(old FSMStateID, new FSMStateID, data interface{}) { + func(old FSMStateID, new FSMStateID, data interface{}) bool { stateCallback(t, old, new, data, state) + return true }, false, ) @@ -393,8 +398,9 @@ func Test_prefer_tcp(t *testing.T) { "org.letsconnect-vpn.app.linux", "configsprefertcp", "en", - func(old FSMStateID, new FSMStateID, data interface{}) { + func(old FSMStateID, new FSMStateID, data interface{}) bool { stateCallback(t, old, new, data, state) + return true }, false, ) diff --git a/client/fsm.go b/client/fsm.go index 767ceaa..c93a9a8 100644 --- a/client/fsm.go +++ b/client/fsm.go @@ -96,7 +96,7 @@ func GetStateName(s FSMStateID) string { } func newFSM( - callback func(FSMStateID, FSMStateID, interface{}), + callback func(FSMStateID, FSMStateID, interface{}) bool, directory string, debug bool, ) fsm.FSM { diff --git a/client/server.go b/client/server.go index 3bcecc6..d22dc65 100644 --- a/client/server.go +++ b/client/server.go @@ -479,16 +479,20 @@ func (client *Client) GetConfigCustomServer(url string, preferTCP bool) (string, // askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state. func (client *Client) askSecureLocation() error { + errorMessage := "failed settings secure location" locations := client.Discovery.GetSecureLocationList() // Ask for the location in the callback - client.FSM.GoTransitionWithData(STATE_ASK_LOCATION, locations) + goTransitionErr := client.FSM.GoTransitionRequired(STATE_ASK_LOCATION, locations) + if goTransitionErr != nil { + return types.NewWrappedError(errorMessage, goTransitionErr) + } // The state has changed, meaning setting the secure location was not successful if client.FSM.Current != STATE_ASK_LOCATION { // TODO: maybe a custom type for this errors.new? return types.NewWrappedError( - "failed setting secure location", + errorMessage, errors.New("failed loading secure location"), ) } @@ -579,7 +583,10 @@ func (client *Client) ensureLogin(chosenServer server.Server) error { if server.NeedsRelogin(chosenServer) { url, urlErr := server.GetOAuthURL(chosenServer, client.Name) - client.FSM.GoTransitionWithData(STATE_OAUTH_STARTED, url) + goTransitionErr := client.FSM.GoTransitionRequired(STATE_OAUTH_STARTED, url) + if goTransitionErr != nil { + return types.NewWrappedError(errorMessage, goTransitionErr) + } if urlErr != nil { client.goBackInternal() diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 97d0f39..96ca42f 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -125,8 +125,9 @@ func printConfig(url string, serverType ServerTypes) { "org.eduvpn.app.linux", "configs", "en", - func(old client.FSMStateID, new client.FSMStateID, data interface{}) { + func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool { stateCallback(state, old, new, data) + return true }, true, ) diff --git a/exports/exports.go b/exports/exports.go index 05f2462..8e418d4 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -4,11 +4,11 @@ package main #include <stdlib.h> #include "error.h" -typedef void (*PythonCB)(const char* name, int oldstate, int newstate, void* data); +typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data); -static void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) +static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) { - callback(name, oldstate, newstate, data); + return callback(name, oldstate, newstate, data); } */ import "C" @@ -61,18 +61,19 @@ func StateCallback( old_state client.FSMStateID, new_state client.FSMStateID, data interface{}, -) { +) bool { P_StateCallback, exists := P_StateCallbacks[name] if !exists || P_StateCallback == nil { - return + return false } name_c := C.CString(name) oldState_c := C.int(old_state) newState_c := C.int(new_state) data_c := GetStateData(state, new_state, data) - C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c) + handled := C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c) C.free(unsafe.Pointer(name_c)) // data_c gets freed by the wrapper + return handled == C.int(1) } func GetVPNState(name string) (*client.Client, error) { @@ -110,8 +111,8 @@ func Register( nameStr, C.GoString(config_directory), C.GoString(language), - func(old client.FSMStateID, new client.FSMStateID, data interface{}) { - StateCallback(state, nameStr, old, new, data) + func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool { + return StateCallback(state, nameStr, old, new, data) }, debug != 0, ) diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index 198d51a..1b45ce8 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -6,6 +6,7 @@ import ( "os/exec" "path" "sort" + "github.com/eduvpn/eduvpn-common/types" ) type ( @@ -47,7 +48,7 @@ type FSM struct { // Info to be passed from the parent state Name string - StateCallback func(FSMStateID, FSMStateID, interface{}) + StateCallback func(FSMStateID, FSMStateID, interface{}) bool Directory string Debug bool GetName func(FSMStateID) string @@ -56,7 +57,7 @@ type FSM struct { func (fsm *FSM) Init( current FSMStateID, states map[FSMStateID]FSMState, - callback func(FSMStateID, FSMStateID, interface{}), + callback func(FSMStateID, FSMStateID, interface{}) bool, directory string, nameGen func(FSMStateID) string, debug bool, @@ -110,9 +111,18 @@ func (fsm *FSM) GoBack() { fsm.GoTransition(fsm.States[fsm.Current].BackState) } +func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error { + oldState := fsm.Current + if !fsm.GoTransitionWithData(newState, data) { + return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetName(oldState), fsm.GetName(newState))) + } + return nil +} + func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool { ok := fsm.HasTransition(newState) + handled := false if ok { oldState := fsm.Current fsm.Current = newState @@ -120,14 +130,15 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool fsm.writeGraph() } - fsm.StateCallback(oldState, newState, data) + handled = fsm.StateCallback(oldState, newState, data) } - return ok + return handled } func (fsm *FSM) GoTransition(newState FSMStateID) bool { - return fsm.GoTransitionWithData(newState, "{}") + // No data means the callback is never required + return fsm.GoTransitionWithData(newState, "") } func (fsm *FSM) generateMermaidGraph() string { diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py index 1823130..4387222 100644 --- a/wrappers/python/eduvpn_common/event.py +++ b/wrappers/python/eduvpn_common/event.py @@ -147,7 +147,7 @@ class EventHandler(object): def run_state( self, state: int, other_state: int, state_type: StateType, data: str - ) -> None: + ) -> bool: """The function that runs the callback for a specific event :param state: int: The state of the event @@ -158,13 +158,14 @@ class EventHandler(object): :meta private: """ if (state, state_type) not in self.handlers: - return + return False for func in self.handlers[(state, state_type)]: func(other_state, data) + return True def run( self, old_state: int, new_state: int, data: Any, convert: bool = True - ) -> None: + ) -> bool: """Run a specific event. It converts the data and then runs the event for all state types @@ -180,5 +181,9 @@ class EventHandler(object): if convert: converted = convert_data(self.lib, new_state, data) self.run_state(old_state, new_state, StateType.LEAVE, converted) - self.run_state(new_state, old_state, StateType.ENTER, converted) - self.run_state(new_state, old_state, StateType.WAIT, converted) + # We decide handled based on enter transitions + handled = self.run_state(new_state, old_state, StateType.ENTER, converted) + # Only run wait transitions if the enter transition is handled + if handled: + self.run_state(new_state, old_state, StateType.WAIT, converted) + return handled diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index fff7e31..20d646f 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,4 +1,5 @@ import threading +from ctypes import c_int from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from eduvpn_common.discovery import DiscoOrganizations, DiscoServers, get_disco_organizations, get_disco_servers @@ -371,7 +372,7 @@ class EduVPN(object): """ return self.event_handler - def callback(self, old_state: State, new_state: State, data: Any) -> None: + def callback(self, old_state: State, new_state: State, data: Any) -> bool: """Run an event callback :param old_state: State: The previous state @@ -379,7 +380,7 @@ class EduVPN(object): :param data: Any: The data to pass to the event """ - self.event.run(old_state, new_state, data) + return self.event.run(old_state, new_state, data) def set_profile(self, profile_id: str) -> None: """Set the profile of the current server @@ -506,7 +507,7 @@ eduvpn_objects: Dict[str, EduVPN] = {} @VPNStateChange -def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> None: +def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> int: """The internal callback that is passed to the Go library :param name: bytes: The name of the client @@ -518,8 +519,11 @@ def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> No """ name_decoded = name.decode() if name_decoded not in eduvpn_objects: - return - eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data) + return 0 + handled = eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data) + if handled: + return 1 + return 0 def add_as_global_object(eduvpn: EduVPN) -> bool: diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index a6eda43..07a02d3 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -165,7 +165,7 @@ class ConfigError(Structure): # The type for a Go state change callback -VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) +VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p) def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: |
