From 4a4b3f0a1c008e35a4492b7fd05176d1822c7232 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Tue, 22 Nov 2022 16:14:06 +0100 Subject: FSM: Check unhandled transitions --- wrappers/python/eduvpn_common/event.py | 15 ++++++++++----- wrappers/python/eduvpn_common/main.py | 14 +++++++++----- wrappers/python/eduvpn_common/types.py | 2 +- 3 files changed, 20 insertions(+), 11 deletions(-) (limited to 'wrappers/python/eduvpn_common') diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py index 1823130..4387222 100644 --- a/wrappers/python/eduvpn_common/event.py +++ b/wrappers/python/eduvpn_common/event.py @@ -147,7 +147,7 @@ class EventHandler(object): def run_state( self, state: int, other_state: int, state_type: StateType, data: str - ) -> None: + ) -> bool: """The function that runs the callback for a specific event :param state: int: The state of the event @@ -158,13 +158,14 @@ class EventHandler(object): :meta private: """ if (state, state_type) not in self.handlers: - return + return False for func in self.handlers[(state, state_type)]: func(other_state, data) + return True def run( self, old_state: int, new_state: int, data: Any, convert: bool = True - ) -> None: + ) -> bool: """Run a specific event. It converts the data and then runs the event for all state types @@ -180,5 +181,9 @@ class EventHandler(object): if convert: converted = convert_data(self.lib, new_state, data) self.run_state(old_state, new_state, StateType.LEAVE, converted) - self.run_state(new_state, old_state, StateType.ENTER, converted) - self.run_state(new_state, old_state, StateType.WAIT, converted) + # We decide handled based on enter transitions + handled = self.run_state(new_state, old_state, StateType.ENTER, converted) + # Only run wait transitions if the enter transition is handled + if handled: + self.run_state(new_state, old_state, StateType.WAIT, converted) + return handled diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index fff7e31..20d646f 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,4 +1,5 @@ import threading +from ctypes import c_int from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from eduvpn_common.discovery import DiscoOrganizations, DiscoServers, get_disco_organizations, get_disco_servers @@ -371,7 +372,7 @@ class EduVPN(object): """ return self.event_handler - def callback(self, old_state: State, new_state: State, data: Any) -> None: + def callback(self, old_state: State, new_state: State, data: Any) -> bool: """Run an event callback :param old_state: State: The previous state @@ -379,7 +380,7 @@ class EduVPN(object): :param data: Any: The data to pass to the event """ - self.event.run(old_state, new_state, data) + return self.event.run(old_state, new_state, data) def set_profile(self, profile_id: str) -> None: """Set the profile of the current server @@ -506,7 +507,7 @@ eduvpn_objects: Dict[str, EduVPN] = {} @VPNStateChange -def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> None: +def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> int: """The internal callback that is passed to the Go library :param name: bytes: The name of the client @@ -518,8 +519,11 @@ def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> No """ name_decoded = name.decode() if name_decoded not in eduvpn_objects: - return - eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data) + return 0 + handled = eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data) + if handled: + return 1 + return 0 def add_as_global_object(eduvpn: EduVPN) -> bool: diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index a6eda43..07a02d3 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -165,7 +165,7 @@ class ConfigError(Structure): # The type for a Go state change callback -VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) +VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p) def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: -- cgit v1.2.3