summaryrefslogtreecommitdiff
path: root/wrappers/python/src/main.py
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-06-21 18:19:11 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-06-21 18:19:11 +0200
commit8fbe53fb2f90ca7c7410621581abca35bc3e749c (patch)
tree82748714a37ba634123bed2d5f731a7c5f0ff5fa /wrappers/python/src/main.py
parent1f20c4069d354167548241ea09d37dad82ecf10a (diff)
Python/Exports: Separate events and use a map with the name for callbacks
Also adds a helper to call Go functions with the proper encoding from Python :^)
Diffstat (limited to 'wrappers/python/src/main.py')
-rw-r--r--wrappers/python/src/main.py256
1 files changed, 49 insertions, 207 deletions
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 03757df..76a08ab 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,178 +1,29 @@
-from . import lib, VPNStateChange, GetDataError, GetMultipleDataError, GetPtrString
-from ctypes import *
-from enum import Enum
-from typing import Callable, Optional, Tuple
-from functools import wraps
+from . import lib, VPNStateChange, encode_args, decode_res
+from typing import Optional, Tuple
import threading
+from .event import StateType, EventHandler
+eduvpn_objects = {}
-class StateType(Enum):
- Enter = 1
- Leave = 2
- Wait = 3
-
-EDUVPN_CALLBACK_PROPERTY = '_eduvpn_property_callback'
-
-# A state transition decorator for classes
-# To use this, make sure to register the class with `register_class_callbacks`
-def class_state_transition(state: str, state_type: StateType) -> Callable:
- def wrapper(func):
- setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type))
- return func
- return wrapper
-
-class EventHandler(object):
- def __init__(self):
- self.handlers = {}
-
- def remove_event(self, state: str, state_type: StateType, func: Callable):
- for key, values in self.handlers.copy().items():
- if key == (state, state_type):
- values.remove(func)
- if not values:
- del self.handlers[key]
- else:
- self.handlers[key] = values
-
- def add_event(self, state: str, state_type: StateType, func: Callable):
- if (state, state_type) not in self.handlers:
- self.handlers[(state, state_type)] = []
- self.handlers[(state, state_type)].append(func)
-
- # A decorator for standalone functions
- def on(self, state: str, state_type: StateType) -> Callable:
- def wrapped_f(func):
- self.add_event(state, state_type, func)
- return func
-
- return wrapped_f
-
- def run_state(
- self, state: str, other_state: str, state_type: StateType, data: str
- ) -> None:
- if (state, state_type) not in self.handlers:
- return
- for func in self.handlers[(state, state_type)]:
- func(other_state, data)
-
- def run(self, old_state: str, new_state: str, data: str) -> None:
- if old_state == new_state:
- return
-
- # First run leave transitions, then enter
- # The state is done when the wait event finishes
- self.run_state(old_state, new_state, StateType.Leave, data)
- self.run_state(new_state, old_state, StateType.Enter, data)
- self.run_state(new_state, old_state, StateType.Wait, data)
-
-
-# Registers the python app with the Go code
-# name: The name of the app to be registered
-# state_callback: The callback to trigger whenever a state is changed
-def Register(
- name: str, config_directory: str, state_callback: Optional[Callable], debug: bool
-) -> str:
- if not state_callback:
- return "No callback provided"
- name_bytes = name.encode("utf-8")
- dir_bytes = config_directory.encode("utf-8")
- ptr_err = lib.Register(name_bytes, dir_bytes, state_callback, debug)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def CancelOAuth(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.CancelOAuth(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def Deregister(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.Deregister(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def GetDiscoServers(name: str) -> Tuple[str, str]:
- name_bytes = name.encode("utf-8")
- servers, servers_err = GetDataError(lib.GetServersList(name_bytes))
- return servers, servers_err
-
-
-def GetDiscoOrganizations(name: str) -> Tuple[str, str]:
- name_bytes = name.encode("utf-8")
- organizations, organizations_err = GetDataError(
- lib.GetOrganizationsList(name_bytes)
- )
- return organizations, organizations_err
-
-
-def GetConnectConfig(
- name: str, url: str, is_secure_internet: bool, force_tcp: bool
-) -> Tuple[str, str, str]:
- name_bytes = name.encode("utf-8")
- url_bytes = url.encode("utf-8")
- multiple_data_error = lib.GetConnectConfig(
- name_bytes, url_bytes, is_secure_internet, force_tcp
- )
- return GetMultipleDataError(multiple_data_error)
-
-
-def SetConnected(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.SetConnected(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def SetDisconnected(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.SetDisconnected(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
+def add_as_global_object(eduvpn) -> bool:
+ global eduvpn_objects
+ if eduvpn.name not in eduvpn_objects:
+ eduvpn_objects[eduvpn.name] = eduvpn
+ return True
+ return False
-def SetSearchServer(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.SetSearchServer(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def SetIdentifier(name: str, identifier: str) -> str:
- name_bytes = name.encode("utf-8")
- identifier_bytes = identifier.encode("utf-8")
- ptr_err = lib.SetIdentifier(name_bytes, identifier_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def GetIdentifier(name: str) -> Tuple[str, str]:
- name_bytes = name.encode("utf-8")
- identifier, identifier_err = GetDataError(lib.GetIdentifier(name_bytes))
- return identifier, identifier_err
+def remove_as_global_object(eduvpn):
+ global eduvpn_objects
+ eduvpn_objects.pop(eduvpn.name, None)
-# This has to be global as otherwise the callback is not alive
-callback_function = None
-
-
-def register_callback(eduvpn):
- global callback_function
- callback_function = VPNStateChange(
- lambda old_state, new_state, data: eduvpn.callback(
- old_state.decode(), new_state.decode(), data.decode()
- )
- )
-
-
-def SetProfileID(name: str, profile_id: str) -> str:
- name_bytes = name.encode("utf-8")
- profile_bytes = profile_id.encode("utf-8")
- error_string = lib.SetProfileID(name_bytes, profile_bytes)
- return GetPtrString(error_string)
+@VPNStateChange
+def state_callback(name, old_state, new_state, data):
+ name = name.decode()
+ if name not in eduvpn_objects:
+ return
+ eduvpn_objects[name].callback(old_state.decode(), new_state.decode(), data.decode())
class EduVPN(object):
@@ -180,40 +31,50 @@ class EduVPN(object):
self.event_handler = EventHandler()
self.name = name
self.config_directory = config_directory
- register_callback(self)
# Callbacks that need to wait for specific events
# The ask profile callback needs to wait for the UI thread to select a profile
# This is stored in the profile_event
self.profile_event: Optional[threading.Event] = None
+
@self.event.on("Ask_Profile", StateType.Wait)
def wait_profile_event(old_state: str, profiles: str):
if self.profile_event:
self.profile_event.wait()
+ def go_function(self, func, *args):
+ # The functions all have at least one arg type which is the name of the client
+ args_gen = encode_args(list(args), func.argtypes[1:])
+ res = func(self.name.encode("utf-8"), *(args_gen))
+ return decode_res(func.restype)(res)
+
def cancel_oauth(self) -> None:
- cancel_oauth_err = CancelOAuth(self.name)
+ cancel_oauth_err = self.go_function(lib.CancelOAuth)
if cancel_oauth_err:
raise Exception(cancel_oauth_err)
def deregister(self) -> None:
- deregister_err = Deregister(self.name)
+ deregister_err = self.go_function(lib.Deregister)
+ remove_as_global_object(self)
if deregister_err:
raise Exception(deregister_err)
def register(self, debug: bool = False) -> None:
- register_err = Register(
- self.name, self.config_directory, callback_function, debug
+ if not add_as_global_object(self):
+ raise Exception("Already registered")
+
+ register_err = self.go_function(
+ lib.Register, self.config_directory, state_callback, debug
)
if register_err:
raise Exception(register_err)
def get_disco_servers(self) -> str:
- servers, servers_err = GetDiscoServers(self.name)
+ servers, servers_err = self.go_function(lib.GetDiscoServers)
if servers_err:
raise Exception(servers_err)
@@ -221,20 +82,22 @@ class EduVPN(object):
return servers
def get_disco_organizations(self) -> str:
- organizations, organizations_err = GetDiscoOrganizations(self.name)
+ organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations)
if organizations_err:
raise Exception(organizations_err)
return organizations
- def get_config(self, url: str, is_secure_internet: bool = False, force_tcp: bool = False):
+ def get_config(
+ self, url: str, is_secure_internet: bool = False, force_tcp: bool = False
+ ):
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
self.profile_event = threading.Event()
- config, config_type, config_err = GetConnectConfig(
- self.name, url, is_secure_internet, force_tcp
+ config, config_type, config_err = self.go_function(
+ lib.GetConnectConfig, url, is_secure_internet, force_tcp
)
if config_err:
@@ -255,19 +118,19 @@ class EduVPN(object):
return self.get_config(url, True, force_tcp)
def set_connected(self) -> None:
- connect_err = SetConnected(self.name)
+ connect_err = self.go_function(lib.SetConnected)
if connect_err:
raise Exception(connect_err)
def set_disconnected(self) -> None:
- disconnect_err = SetDisconnected(self.name)
+ disconnect_err = self.go_function(lib.SetDisconnected)
if disconnect_err:
raise Exception(disconnect_err)
def get_identifier(self) -> str:
- identifier, identifier_err = GetIdentifier(self.name)
+ identifier, identifier_err = self.go_function(lib.GetIdentifier)
if identifier_err:
raise Exception(identifier_err)
@@ -275,43 +138,22 @@ class EduVPN(object):
return identifier
def set_identifier(self, identifier: str) -> None:
- identifier_err = SetIdentifier(self.name, identifier)
+ identifier_err = self.go_function(lib.SetIdentifier, identifier)
if identifier_err:
raise Exception(identifier_err)
def set_search_server(self) -> None:
- search_err = SetSearchServer(self.name)
+ search_err = self.go_function(lib.SetSearchServer)
if search_err:
raise Exception(search_err)
- def change_class_callbacks(self, cls, add=True) -> None:
- # Loop over method names
- for method_name in dir(cls):
-
- try:
- # Get the method
- method = getattr(cls, method_name)
- except:
- # Unable to get a value, go to the next
- continue
-
- # If it has a callback defined, add it to the events
- method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None)
- if method_value:
- state, state_type = method_value
-
- if add:
- self.event.add_event(state, state_type, method)
- else:
- self.event.remove_event(state, state_type, method)
-
def remove_class_callbacks(self, cls) -> None:
- self.change_class_callbacks(cls, add=False)
+ self.event_handler.change_class_callbacks(cls, add=False)
def register_class_callbacks(self, cls) -> None:
- self.change_class_callbacks(cls)
+ self.event_handler.change_class_callbacks(cls)
@property
def event(self) -> EventHandler:
@@ -322,7 +164,7 @@ class EduVPN(object):
def set_profile(self, profile_id: str) -> None:
# Set the profile id
- profile_err = SetProfileID(self.name, profile_id)
+ profile_err = self.go_function(lib.SetProfileID, profile_id)
if profile_err:
raise Exception(profile_err)