From b7f86c83fffc221e654048e6e55c3f130c9fd308 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 8 May 2023 16:20:00 +0200 Subject: Client: Use a mutex for state transitions --- wrappers/python/eduvpn_common/main.py | 42 +++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 9 deletions(-) (limited to 'wrappers/python/eduvpn_common/main.py') diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 64fa6ac..f3639e0 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -3,7 +3,17 @@ from enum import IntEnum from typing import Any, Callable, Iterator, Optional from eduvpn_common.loader import initialize_functions, load_lib -from eduvpn_common.types import ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange, decode_res, encode_args +from eduvpn_common.types import ( + ReadRxBytes, + TokenGetter, + TokenSetter, + VPNStateChange, + decode_res, + encode_args, +) + +from eduvpn_common.event import EventHandler +from eduvpn_common.state import State class WrappedError(Exception): @@ -57,14 +67,20 @@ class EduVPN(object): self.version = version self.config_directory = config_directory self.jar = Jar(lambda x: self.go_function(self.lib.CookieCancel, x)) - self.callback = None self.token_setter = None self.token_getter = None + self.event_handler = EventHandler() # Load the library self.lib = load_lib() initialize_functions(self.lib) + def register_class_callbacks(self, _class): + self.event_handler.change_class_callbacks(_class, add=True) + + def deregister_class_callbacks(self, _class): + self.event_handler.change_class_callbacks(_class, add=False) + def go_cookie_function(self, func: Any, *args: Iterator) -> Any: cookie = self.lib.CookieNew() self.jar.add(cookie) @@ -95,7 +111,7 @@ class EduVPN(object): global global_object global_object = None - def register(self, handler: Optional[Callable] = None, debug: bool = False) -> None: + def register(self, debug: bool = False) -> None: """Register the Go shared library. This makes sure the FSM is initialized and that we can call Go functions @@ -106,7 +122,6 @@ class EduVPN(object): global global_object if global_object is not None: raise Exception("Already registered") - self.callback = handler global_object = self register_err = self.go_function( self.lib.Register, @@ -175,11 +190,17 @@ class EduVPN(object): if remove_err: forwardError(remove_err) - def set_state(self, state: int): + def set_state(self, state: State): state_err = self.go_function(self.lib.SetState, state) if state_err: forwardError(state_err) + def in_state(self, state: State) -> bool: + yes, state_err = self.go_function(self.lib.InState, state) + if state_err: + forwardError(state_err) + return yes + def get_config( self, _type: ServerType, identifier: str, prefer_tcp: bool = False ) -> str: @@ -257,7 +278,9 @@ class EduVPN(object): def set_token_handler(self, getter: Callable, setter: Callable) -> None: self.token_setter = setter self.token_getter = getter - handler_err = self.go_function(self.lib.SetTokenHandler, token_getter, token_setter) + handler_err = self.go_function( + self.lib.SetTokenHandler, token_getter, token_setter + ) if handler_err: forwardError(handler_err) @@ -309,6 +332,7 @@ class EduVPN(object): global_object: Optional[EduVPN] = None + @TokenSetter def token_setter(server: ctypes.c_char_p, tokens: ctypes.c_char_p): global global_object @@ -318,6 +342,7 @@ def token_setter(server: ctypes.c_char_p, tokens: ctypes.c_char_p): return 0 global_object.token_setter(server.decode(), tokens.decode()) + @TokenGetter def token_getter(server: ctypes.c_char_p, buf: ctypes.c_char_p, size: ctypes.c_size_t): global global_object @@ -332,6 +357,7 @@ def token_getter(server: ctypes.c_char_p, buf: ctypes.c_char_p, size: ctypes.c_s outbuf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char * size)) outbuf.contents.value = got.encode("utf-8") + @VPNStateChange def state_callback(old_state: int, new_state: int, data: str) -> int: """The internal callback that is passed to the Go library @@ -345,9 +371,7 @@ def state_callback(old_state: int, new_state: int, data: str) -> int: global global_object if global_object is None: return 0 - if global_object.callback is None: - return 0 - handled = global_object.callback(old_state, new_state, data.decode("utf-8")) + handled = global_object.event_handler.run(State(old_state), State(new_state), data.decode("utf-8")) if handled: return 1 return 0 -- cgit v1.2.3