diff options
Diffstat (limited to 'wrappers/python/src/main.py')
| -rw-r--r-- | wrappers/python/src/main.py | 102 |
1 files changed, 89 insertions, 13 deletions
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 4117a86..03757df 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -2,22 +2,47 @@ from . import lib, VPNStateChange, GetDataError, GetMultipleDataError, GetPtrStr from ctypes import * from enum import Enum from typing import Callable, Optional, Tuple +from functools import wraps +import threading class StateType(Enum): Enter = 1 Leave = 2 + Wait = 3 +EDUVPN_CALLBACK_PROPERTY = '_eduvpn_property_callback' + +# A state transition decorator for classes +# To use this, make sure to register the class with `register_class_callbacks` +def class_state_transition(state: str, state_type: StateType) -> Callable: + def wrapper(func): + setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type)) + return func + return wrapper class EventHandler(object): def __init__(self): self.handlers = {} + def remove_event(self, state: str, state_type: StateType, func: Callable): + for key, values in self.handlers.copy().items(): + if key == (state, state_type): + values.remove(func) + if not values: + del self.handlers[key] + else: + self.handlers[key] = values + + def add_event(self, state: str, state_type: StateType, func: Callable): + if (state, state_type) not in self.handlers: + self.handlers[(state, state_type)] = [] + self.handlers[(state, state_type)].append(func) + + # A decorator for standalone functions def on(self, state: str, state_type: StateType) -> Callable: def wrapped_f(func): - if (state, state_type) not in self.handlers: - self.handlers[(state, state_type)] = [] - self.handlers[(state, state_type)].append(func) + self.add_event(state, state_type, func) return func return wrapped_f @@ -33,8 +58,12 @@ class EventHandler(object): def run(self, old_state: str, new_state: str, data: str) -> None: if old_state == new_state: return + + # First run leave transitions, then enter + # The state is done when the wait event finishes self.run_state(old_state, new_state, StateType.Leave, data) self.run_state(new_state, old_state, StateType.Enter, data) + self.run_state(new_state, old_state, StateType.Wait, data) # Registers the python app with the Go code @@ -153,6 +182,16 @@ class EduVPN(object): self.config_directory = config_directory register_callback(self) + # Callbacks that need to wait for specific events + + # The ask profile callback needs to wait for the UI thread to select a profile + # This is stored in the profile_event + self.profile_event: Optional[threading.Event] = None + @self.event.on("Ask_Profile", StateType.Wait) + def wait_profile_event(old_state: str, profiles: str): + if self.profile_event: + self.profile_event.wait() + def cancel_oauth(self) -> None: cancel_oauth_err = CancelOAuth(self.name) @@ -189,31 +228,35 @@ class EduVPN(object): return organizations - def get_config_institute_access( - self, url: str, force_tcp: bool = False - ) -> Tuple[str, str]: + def get_config(self, url: str, is_secure_internet: bool = False, force_tcp: bool = False): + # Because it could be the case that a profile callback is started, store a threading event + # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set + # The event is set in self.set_profile + self.profile_event = threading.Event() config, config_type, config_err = GetConnectConfig( - self.name, url, False, force_tcp + self.name, url, is_secure_internet, force_tcp ) if config_err: raise Exception(config_err) + self.profile_event = None + return config, config_type + def get_config_institute_access( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + return self.get_config(url, False, force_tcp) + def get_config_secure_internet( self, url: str, force_tcp: bool = False ) -> Tuple[str, str]: - config, config_type, config_err = GetConnectConfig( - self.name, url, True, force_tcp - ) + return self.get_config(url, True, force_tcp) - if config_err: - raise Exception(config_err) def set_connected(self) -> None: connect_err = SetConnected(self.name) - return config, config_type if connect_err: raise Exception(connect_err) @@ -243,6 +286,33 @@ class EduVPN(object): if search_err: raise Exception(search_err) + def change_class_callbacks(self, cls, add=True) -> None: + # Loop over method names + for method_name in dir(cls): + + try: + # Get the method + method = getattr(cls, method_name) + except: + # Unable to get a value, go to the next + continue + + # If it has a callback defined, add it to the events + method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None) + if method_value: + state, state_type = method_value + + if add: + self.event.add_event(state, state_type, method) + else: + self.event.remove_event(state, state_type, method) + + def remove_class_callbacks(self, cls) -> None: + self.change_class_callbacks(cls, add=False) + + def register_class_callbacks(self, cls) -> None: + self.change_class_callbacks(cls) + @property def event(self) -> EventHandler: return self.event_handler @@ -251,7 +321,13 @@ class EduVPN(object): self.event.run(old_state, new_state, data) def set_profile(self, profile_id: str) -> None: + # Set the profile id profile_err = SetProfileID(self.name, profile_id) if profile_err: raise Exception(profile_err) + + # If there is a profile event, set it so that the wait callback finishes + # And so that the Go code can move to the next state + if self.profile_event: + self.profile_event.set() |
