diff options
| -rw-r--r-- | exports/exports.go | 26 | ||||
| -rw-r--r-- | wrappers/python/src/__init__.py | 37 | ||||
| -rw-r--r-- | wrappers/python/src/event.py | 86 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 256 |
4 files changed, 180 insertions, 225 deletions
diff --git a/exports/exports.go b/exports/exports.go index 1081754..567e189 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -3,12 +3,12 @@ package main /* #include <stdlib.h> -typedef void (*PythonCB)(const char* oldstate, const char* newstate, const char* data); +typedef void (*PythonCB)(const char* name, const char* oldstate, const char* newstate, const char* data); __attribute__((weak)) -void call_callback(PythonCB callback, const char* oldstate, const char* newstate, const char* data) +void call_callback(PythonCB callback, const char *name, const char* oldstate, const char* newstate, const char* data) { - callback(oldstate, newstate, data); + callback(name, oldstate, newstate, data); } */ import "C" @@ -21,18 +21,21 @@ import ( "github.com/jwijenbergh/eduvpn-common" ) -var P_StateCallback C.PythonCB +var P_StateCallbacks map[string]C.PythonCB var VPNStates map[string]*eduvpn.VPNState -func StateCallback(old_state string, new_state string, data string) { - if P_StateCallback == nil { +func StateCallback(name string, old_state string, new_state string, data string) { + P_StateCallback, exists := P_StateCallbacks[name] + if !exists || P_StateCallback == nil { return } + name_c := C.CString(name) oldState_c := C.CString(old_state) newState_c := C.CString(new_state) data_c := C.CString(data) - C.call_callback(P_StateCallback, oldState_c, newState_c, data_c) + C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c) + C.free(unsafe.Pointer(name_c)) C.free(unsafe.Pointer(oldState_c)) C.free(unsafe.Pointer(newState_c)) C.free(unsafe.Pointer(data_c)) @@ -58,9 +61,14 @@ func Register(name *C.char, config_directory *C.char, stateCallback C.PythonCB, if VPNStates == nil { VPNStates = make(map[string]*eduvpn.VPNState) } + if P_StateCallbacks == nil { + P_StateCallbacks = make(map[string]C.PythonCB) + } VPNStates[nameStr] = state - P_StateCallback = stateCallback - registerErr := state.Register(nameStr, C.GoString(config_directory), StateCallback, debug != 0) + P_StateCallbacks[nameStr] = stateCallback + registerErr := state.Register(nameStr, C.GoString(config_directory), func(old string, new string, data string) { + StateCallback(nameStr, old, new, data) + }, debug != 0) if registerErr != nil { delete(VPNStates, nameStr) diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index d260916..8129495 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -35,7 +35,7 @@ class MultipleDataError(Structure): _fields_ = [("data", c_void_p), ("other_data", c_void_p), ("error", c_void_p)] -VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p) +VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p, c_char_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly @@ -67,7 +67,19 @@ lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None -def GetPtrString(ptr: c_void_p) -> str: +def encode_args(args, types): + for arg, t in zip(args, types): + # c_char_p needs the str to be encoded to bytes + if t is c_char_p: + arg = arg.encode("utf-8") + yield arg + + +def decode_res(t): + return decode_map.get(t, lambda x: x) + + +def get_ptr_string(ptr: c_void_p) -> str: if ptr: string = cast(ptr, c_char_p).value lib.FreeString(ptr) @@ -76,16 +88,23 @@ def GetPtrString(ptr: c_void_p) -> str: return "" -def GetDataError(data_error: DataError) -> Tuple[str, str]: - data = GetPtrString(data_error.data) - error = GetPtrString(data_error.error) +def get_data_error(data_error: DataError) -> Tuple[str, str]: + data = get_ptr_string(data_error.data) + error = get_ptr_string(data_error.error) return data, error -def GetMultipleDataError( +def get_multiple_data_error( multiple_data_error: MultipleDataError, ) -> Tuple[str, str, str]: - data = GetPtrString(multiple_data_error.data) - other_data = GetPtrString(multiple_data_error.other_data) - error = GetPtrString(multiple_data_error.error) + data = get_ptr_string(multiple_data_error.data) + other_data = get_ptr_string(multiple_data_error.other_data) + error = get_ptr_string(multiple_data_error.error) return data, other_data, error + + +decode_map = { + c_void_p: get_ptr_string, + DataError: get_data_error, + MultipleDataError: get_multiple_data_error, +} diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py new file mode 100644 index 0000000..778ce5e --- /dev/null +++ b/wrappers/python/src/event.py @@ -0,0 +1,86 @@ +from . import VPNStateChange +from enum import Enum +from typing import Callable + + +class StateType(Enum): + Enter = 1 + Leave = 2 + Wait = 3 + + +EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" + +# A state transition decorator for classes +# To use this, make sure to register the class with `register_class_callbacks` +def class_state_transition(state: str, state_type: StateType) -> Callable: + def wrapper(func): + setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type)) + return func + + return wrapper + + +class EventHandler(object): + def __init__(self): + self.handlers = {} + + def change_class_callbacks(self, cls, add=True) -> None: + # Loop over method names + for method_name in dir(cls): + try: + # Get the method + method = getattr(cls, method_name) + except: + # Unable to get a value, go to the next + continue + + # If it has a callback defined, add it to the events + method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None) + if method_value: + state, state_type = method_value + + if add: + self.add_event(state, state_type, method) + else: + self.remove_event(state, state_type, method) + + def remove_event(self, state: str, state_type: StateType, func: Callable): + for key, values in self.handlers.copy().items(): + if key == (state, state_type): + values.remove(func) + if not values: + del self.handlers[key] + else: + self.handlers[key] = values + + def add_event(self, state: str, state_type: StateType, func: Callable): + if (state, state_type) not in self.handlers: + self.handlers[(state, state_type)] = [] + self.handlers[(state, state_type)].append(func) + + # A decorator for standalone functions + def on(self, state: str, state_type: StateType) -> Callable: + def wrapped_f(func): + self.add_event(state, state_type, func) + return func + + return wrapped_f + + def run_state( + self, state: str, other_state: str, state_type: StateType, data: str + ) -> None: + if (state, state_type) not in self.handlers: + return + for func in self.handlers[(state, state_type)]: + func(other_state, data) + + def run(self, old_state: str, new_state: str, data: str) -> None: + if old_state == new_state: + return + + # First run leave transitions, then enter + # The state is done when the wait event finishes + self.run_state(old_state, new_state, StateType.Leave, data) + self.run_state(new_state, old_state, StateType.Enter, data) + self.run_state(new_state, old_state, StateType.Wait, data) diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 03757df..76a08ab 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,178 +1,29 @@ -from . import lib, VPNStateChange, GetDataError, GetMultipleDataError, GetPtrString -from ctypes import * -from enum import Enum -from typing import Callable, Optional, Tuple -from functools import wraps +from . import lib, VPNStateChange, encode_args, decode_res +from typing import Optional, Tuple import threading +from .event import StateType, EventHandler +eduvpn_objects = {} -class StateType(Enum): - Enter = 1 - Leave = 2 - Wait = 3 - -EDUVPN_CALLBACK_PROPERTY = '_eduvpn_property_callback' - -# A state transition decorator for classes -# To use this, make sure to register the class with `register_class_callbacks` -def class_state_transition(state: str, state_type: StateType) -> Callable: - def wrapper(func): - setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type)) - return func - return wrapper - -class EventHandler(object): - def __init__(self): - self.handlers = {} - - def remove_event(self, state: str, state_type: StateType, func: Callable): - for key, values in self.handlers.copy().items(): - if key == (state, state_type): - values.remove(func) - if not values: - del self.handlers[key] - else: - self.handlers[key] = values - - def add_event(self, state: str, state_type: StateType, func: Callable): - if (state, state_type) not in self.handlers: - self.handlers[(state, state_type)] = [] - self.handlers[(state, state_type)].append(func) - - # A decorator for standalone functions - def on(self, state: str, state_type: StateType) -> Callable: - def wrapped_f(func): - self.add_event(state, state_type, func) - return func - - return wrapped_f - - def run_state( - self, state: str, other_state: str, state_type: StateType, data: str - ) -> None: - if (state, state_type) not in self.handlers: - return - for func in self.handlers[(state, state_type)]: - func(other_state, data) - - def run(self, old_state: str, new_state: str, data: str) -> None: - if old_state == new_state: - return - - # First run leave transitions, then enter - # The state is done when the wait event finishes - self.run_state(old_state, new_state, StateType.Leave, data) - self.run_state(new_state, old_state, StateType.Enter, data) - self.run_state(new_state, old_state, StateType.Wait, data) - - -# Registers the python app with the Go code -# name: The name of the app to be registered -# state_callback: The callback to trigger whenever a state is changed -def Register( - name: str, config_directory: str, state_callback: Optional[Callable], debug: bool -) -> str: - if not state_callback: - return "No callback provided" - name_bytes = name.encode("utf-8") - dir_bytes = config_directory.encode("utf-8") - ptr_err = lib.Register(name_bytes, dir_bytes, state_callback, debug) - err_string = GetPtrString(ptr_err) - return err_string - - -def CancelOAuth(name: str) -> str: - name_bytes = name.encode("utf-8") - ptr_err = lib.CancelOAuth(name_bytes) - err_string = GetPtrString(ptr_err) - return err_string - - -def Deregister(name: str) -> str: - name_bytes = name.encode("utf-8") - ptr_err = lib.Deregister(name_bytes) - err_string = GetPtrString(ptr_err) - return err_string - - -def GetDiscoServers(name: str) -> Tuple[str, str]: - name_bytes = name.encode("utf-8") - servers, servers_err = GetDataError(lib.GetServersList(name_bytes)) - return servers, servers_err - - -def GetDiscoOrganizations(name: str) -> Tuple[str, str]: - name_bytes = name.encode("utf-8") - organizations, organizations_err = GetDataError( - lib.GetOrganizationsList(name_bytes) - ) - return organizations, organizations_err - - -def GetConnectConfig( - name: str, url: str, is_secure_internet: bool, force_tcp: bool -) -> Tuple[str, str, str]: - name_bytes = name.encode("utf-8") - url_bytes = url.encode("utf-8") - multiple_data_error = lib.GetConnectConfig( - name_bytes, url_bytes, is_secure_internet, force_tcp - ) - return GetMultipleDataError(multiple_data_error) - - -def SetConnected(name: str) -> str: - name_bytes = name.encode("utf-8") - ptr_err = lib.SetConnected(name_bytes) - err_string = GetPtrString(ptr_err) - return err_string - - -def SetDisconnected(name: str) -> str: - name_bytes = name.encode("utf-8") - ptr_err = lib.SetDisconnected(name_bytes) - err_string = GetPtrString(ptr_err) - return err_string +def add_as_global_object(eduvpn) -> bool: + global eduvpn_objects + if eduvpn.name not in eduvpn_objects: + eduvpn_objects[eduvpn.name] = eduvpn + return True + return False -def SetSearchServer(name: str) -> str: - name_bytes = name.encode("utf-8") - ptr_err = lib.SetSearchServer(name_bytes) - err_string = GetPtrString(ptr_err) - return err_string - - -def SetIdentifier(name: str, identifier: str) -> str: - name_bytes = name.encode("utf-8") - identifier_bytes = identifier.encode("utf-8") - ptr_err = lib.SetIdentifier(name_bytes, identifier_bytes) - err_string = GetPtrString(ptr_err) - return err_string - - -def GetIdentifier(name: str) -> Tuple[str, str]: - name_bytes = name.encode("utf-8") - identifier, identifier_err = GetDataError(lib.GetIdentifier(name_bytes)) - return identifier, identifier_err +def remove_as_global_object(eduvpn): + global eduvpn_objects + eduvpn_objects.pop(eduvpn.name, None) -# This has to be global as otherwise the callback is not alive -callback_function = None - - -def register_callback(eduvpn): - global callback_function - callback_function = VPNStateChange( - lambda old_state, new_state, data: eduvpn.callback( - old_state.decode(), new_state.decode(), data.decode() - ) - ) - - -def SetProfileID(name: str, profile_id: str) -> str: - name_bytes = name.encode("utf-8") - profile_bytes = profile_id.encode("utf-8") - error_string = lib.SetProfileID(name_bytes, profile_bytes) - return GetPtrString(error_string) +@VPNStateChange +def state_callback(name, old_state, new_state, data): + name = name.decode() + if name not in eduvpn_objects: + return + eduvpn_objects[name].callback(old_state.decode(), new_state.decode(), data.decode()) class EduVPN(object): @@ -180,40 +31,50 @@ class EduVPN(object): self.event_handler = EventHandler() self.name = name self.config_directory = config_directory - register_callback(self) # Callbacks that need to wait for specific events # The ask profile callback needs to wait for the UI thread to select a profile # This is stored in the profile_event self.profile_event: Optional[threading.Event] = None + @self.event.on("Ask_Profile", StateType.Wait) def wait_profile_event(old_state: str, profiles: str): if self.profile_event: self.profile_event.wait() + def go_function(self, func, *args): + # The functions all have at least one arg type which is the name of the client + args_gen = encode_args(list(args), func.argtypes[1:]) + res = func(self.name.encode("utf-8"), *(args_gen)) + return decode_res(func.restype)(res) + def cancel_oauth(self) -> None: - cancel_oauth_err = CancelOAuth(self.name) + cancel_oauth_err = self.go_function(lib.CancelOAuth) if cancel_oauth_err: raise Exception(cancel_oauth_err) def deregister(self) -> None: - deregister_err = Deregister(self.name) + deregister_err = self.go_function(lib.Deregister) + remove_as_global_object(self) if deregister_err: raise Exception(deregister_err) def register(self, debug: bool = False) -> None: - register_err = Register( - self.name, self.config_directory, callback_function, debug + if not add_as_global_object(self): + raise Exception("Already registered") + + register_err = self.go_function( + lib.Register, self.config_directory, state_callback, debug ) if register_err: raise Exception(register_err) def get_disco_servers(self) -> str: - servers, servers_err = GetDiscoServers(self.name) + servers, servers_err = self.go_function(lib.GetDiscoServers) if servers_err: raise Exception(servers_err) @@ -221,20 +82,22 @@ class EduVPN(object): return servers def get_disco_organizations(self) -> str: - organizations, organizations_err = GetDiscoOrganizations(self.name) + organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations) if organizations_err: raise Exception(organizations_err) return organizations - def get_config(self, url: str, is_secure_internet: bool = False, force_tcp: bool = False): + def get_config( + self, url: str, is_secure_internet: bool = False, force_tcp: bool = False + ): # Because it could be the case that a profile callback is started, store a threading event # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set # The event is set in self.set_profile self.profile_event = threading.Event() - config, config_type, config_err = GetConnectConfig( - self.name, url, is_secure_internet, force_tcp + config, config_type, config_err = self.go_function( + lib.GetConnectConfig, url, is_secure_internet, force_tcp ) if config_err: @@ -255,19 +118,19 @@ class EduVPN(object): return self.get_config(url, True, force_tcp) def set_connected(self) -> None: - connect_err = SetConnected(self.name) + connect_err = self.go_function(lib.SetConnected) if connect_err: raise Exception(connect_err) def set_disconnected(self) -> None: - disconnect_err = SetDisconnected(self.name) + disconnect_err = self.go_function(lib.SetDisconnected) if disconnect_err: raise Exception(disconnect_err) def get_identifier(self) -> str: - identifier, identifier_err = GetIdentifier(self.name) + identifier, identifier_err = self.go_function(lib.GetIdentifier) if identifier_err: raise Exception(identifier_err) @@ -275,43 +138,22 @@ class EduVPN(object): return identifier def set_identifier(self, identifier: str) -> None: - identifier_err = SetIdentifier(self.name, identifier) + identifier_err = self.go_function(lib.SetIdentifier, identifier) if identifier_err: raise Exception(identifier_err) def set_search_server(self) -> None: - search_err = SetSearchServer(self.name) + search_err = self.go_function(lib.SetSearchServer) if search_err: raise Exception(search_err) - def change_class_callbacks(self, cls, add=True) -> None: - # Loop over method names - for method_name in dir(cls): - - try: - # Get the method - method = getattr(cls, method_name) - except: - # Unable to get a value, go to the next - continue - - # If it has a callback defined, add it to the events - method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None) - if method_value: - state, state_type = method_value - - if add: - self.event.add_event(state, state_type, method) - else: - self.event.remove_event(state, state_type, method) - def remove_class_callbacks(self, cls) -> None: - self.change_class_callbacks(cls, add=False) + self.event_handler.change_class_callbacks(cls, add=False) def register_class_callbacks(self, cls) -> None: - self.change_class_callbacks(cls) + self.event_handler.change_class_callbacks(cls) @property def event(self) -> EventHandler: @@ -322,7 +164,7 @@ class EduVPN(object): def set_profile(self, profile_id: str) -> None: # Set the profile id - profile_err = SetProfileID(self.name, profile_id) + profile_err = self.go_function(lib.SetProfileID, profile_id) if profile_err: raise Exception(profile_err) |
