diff options
Diffstat (limited to 'wrappers/python')
| -rw-r--r-- | wrappers/python/main.py | 92 | ||||
| -rw-r--r-- | wrappers/python/src/__init__.py | 36 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 189 | ||||
| -rw-r--r-- | wrappers/python/tests.py | 29 |
4 files changed, 258 insertions, 88 deletions
diff --git a/wrappers/python/main.py b/wrappers/python/main.py index 1c1afd7..c887fed 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -1,32 +1,90 @@ import eduvpncommon.main as eduvpn import webbrowser +import json +# Asks the user for a profile index +# It loops up until a valid input is given +def ask_profile_input(total: int) -> int: + profile_index = None -_eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "configs") + while profile_index is None: + try: + profile_index = int( + input("Please select a profile by inputting a number (e.g. 1): ") + ) + if (profile_index > total) or (profile_index < 1): + print("Invalid profile range") + profile_index = None + except ValueError: + print("Please enter a valid input") + # The profile is one based, move to zero based input + return profile_index - 1 -@_eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter) -def oauth_initialized(url): - print(f"Got OAUTH url {url}") - webbrowser.open(url) +# Sets up the callbacks using the provided class +def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None: + # The callback that starst OAuth + # It needs to open the URL in the web browser + @_eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter) + def oauth_initialized(old_state: str, url: str) -> None: + print(f"Got OAuth URL {url}, old state: {old_state}") + webbrowser.open(url) + # The callback which asks the user for a profile + @_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter) + def ask_profile(old_state: str, profiles: str): + print("Multiple profiles found, you need to select a profile, old state: {old_state}") -@_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter) -def ask_profile(profiles): - print("ASK PROFILE CB", profiles) - _eduvpn.set_profile("prefer-openvpn") + # Parse the profiles as JSON + data = json.loads(profiles) + # Get a lits of profiles + profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]] + total_profiles = len(profile_strings) -success = _eduvpn.register(debug=True) + # Create a list of the strings to standard output + for idx, profile in enumerate(profile_strings): + print(f"{idx+1}. {profile}") -if not success: - print("failed to register") + # Get the profile index from the user + profile_index = ask_profile_input(total_profiles) -print(_eduvpn.get_disco()) + # Set the profile with the index + _eduvpn.set_profile(profile_strings[profile_index]) -config, error = _eduvpn.get_config_institute_access("https://eduvpn.jwijenbergh.com") -if error: - print("Got connect error", error) +# The main entry point +if __name__ == "__main__": + _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "configs") + setup_callbacks(_eduvpn) -print(config) + # Register with the eduVPN-common library + try: + _eduvpn.register(debug=True) + except Exception as e: + print("Failed registering:", e) + + server = input( + "Which Institute Access server do you want to connect to? (e.g. https://eduvpn.example.com): " + ) + + # Ensure we have a valid http prefix + if not server.startswith("http"): + # https by default + server = "https://" + server + + # Get a Wireguard/OpenVPN config + try: + config, config_type = _eduvpn.get_config_institute_access(server) + except Exception as e: + print("Failed to connect:", e) + print(f"Got a config with type: {config_type} and contents:\n{config}") + + # Set the internal FSM state to connected + try: + _eduvpn.set_connected() + except Exception as e: + print("Failed to set connected:", e) + + # Save and exit + _eduvpn.deregister() diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index c028f09..c96a1b2 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -2,6 +2,7 @@ from ctypes import * from collections import defaultdict import pathlib import platform +from typing import Tuple _lib_prefixes = defaultdict( lambda: "lib", @@ -30,15 +31,31 @@ class DataError(Structure): _fields_ = [("data", c_void_p), ("error", c_void_p)] +class MultipleDataError(Structure): + _fields_ = [("data", c_void_p), ("other_data", c_void_p), ("error", c_void_p)] + + VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error -lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [c_char_p, c_char_p, c_int, c_int], DataError +lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [ + c_char_p, + c_char_p, + c_int, + c_int, +], MultipleDataError lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], c_void_p -lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p -lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [c_char_p], DataError +lib.Register.argtypes, lib.Register.restype = [ + c_char_p, + c_char_p, + VPNStateChange, + c_int, +], c_void_p +lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [ + c_char_p +], DataError lib.GetServersList.argtypes, lib.GetServersList.restype = [c_char_p], DataError lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p @@ -47,7 +64,7 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p], c_void_p lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None -def GetPtrString(ptr): +def GetPtrString(ptr: c_void_p) -> str: if ptr: string = cast(ptr, c_char_p).value lib.FreeString(ptr) @@ -56,7 +73,16 @@ def GetPtrString(ptr): return "" -def GetDataError(data_error): +def GetDataError(data_error: DataError) -> Tuple[str, str]: data = GetPtrString(data_error.data) error = GetPtrString(data_error.error) return data, error + + +def GetMultipleDataError( + multiple_data_error: MultipleDataError, +) -> Tuple[str, str, str]: + data = GetPtrString(multiple_data_error.data) + other_data = GetPtrString(multiple_data_error.other_data) + error = GetPtrString(multiple_data_error.error) + return data, other_data, error diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 2b346e3..dd8f36a 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,6 +1,7 @@ -from . import lib, VPNStateChange, GetDataError, GetPtrString +from . import lib, VPNStateChange, GetDataError, GetMultipleDataError, GetPtrString from ctypes import * from enum import Enum +from typing import Callable, Optional, Tuple class StateType(Enum): @@ -8,47 +9,94 @@ class StateType(Enum): Leave = 2 +class EventHandler(object): + def __init__(self): + self.handlers = {} + + def on(self, state: str, state_type: StateType) -> Callable: + def wrapped_f(func): + if (state, state_type) not in self.handlers: + self.handlers[(state, state_type)] = [] + self.handlers[(state, state_type)].append(func) + return func + + return wrapped_f + + def run_state(self, state: str, other_state: str, state_type: StateType, data: str) -> None: + if (state, state_type) not in self.handlers: + return + for func in self.handlers[(state, state_type)]: + func(other_state, data) + + def run(self, old_state: str, new_state: str, data: str) -> None: + if old_state == new_state: + return + self.run_state(old_state, new_state, StateType.Leave, data) + self.run_state(new_state, old_state, StateType.Enter, data) + + # Registers the python app with the Go code # name: The name of the app to be registered # state_callback: The callback to trigger whenever a state is changed -def Register(name, config_directory, state_callback, debug): +def Register( + name: str, config_directory: str, state_callback: Optional[Callable], debug: bool +) -> str: + if not state_callback: + return "No callback provided" name_bytes = name.encode("utf-8") dir_bytes = config_directory.encode("utf-8") ptr_err = lib.Register(name_bytes, dir_bytes, state_callback, debug) err_string = GetPtrString(ptr_err) return err_string -def CancelOAuth(name): + +def CancelOAuth(name: str) -> str: name_bytes = name.encode("utf-8") ptr_err = lib.CancelOAuth(name_bytes) err_string = GetPtrString(ptr_err) return err_string -def Deregister(name): + +def Deregister(name: str) -> str: name_bytes = name.encode("utf-8") ptr_err = lib.Deregister(name_bytes) err_string = GetPtrString(ptr_err) return err_string -def GetDiscoServers(name): + +def GetDiscoServers(name: str) -> Tuple[str, str]: + name_bytes = name.encode("utf-8") + servers, servers_err = GetDataError(lib.GetServersList(name_bytes)) + return servers, servers_err + + +def GetDiscoOrganizations(name: str) -> Tuple[str, str]: name_bytes = name.encode("utf-8") - servers, serversErr = GetDataError(lib.GetServersList(name_bytes)) - organizations, organizationsErr = GetDataError(lib.GetOrganizationsList(name_bytes)) - return servers, serversErr, organizations, organizationsErr + organizations, organizations_err = GetDataError( + lib.GetOrganizationsList(name_bytes) + ) + return organizations, organizations_err + -def GetConnectConfig(name, url, is_secure_internet, force_tcp): +def GetConnectConfig( + name: str, url: str, is_secure_internet: bool, force_tcp: bool +) -> Tuple[str, str, str]: name_bytes = name.encode("utf-8") url_bytes = url.encode("utf-8") - data_error = lib.GetConnectConfig(name_bytes, url_bytes, is_secure_internet, force_tcp) - return GetDataError(data_error) + multiple_data_error = lib.GetConnectConfig( + name_bytes, url_bytes, is_secure_internet, force_tcp + ) + return GetMultipleDataError(multiple_data_error) + -def SetConnected(name): +def SetConnected(name: str) -> str: name_bytes = name.encode("utf-8") ptr_err = lib.SetConnected(name_bytes) err_string = GetPtrString(ptr_err) return err_string -def SetDisconnected(name): + +def SetDisconnected(name: str) -> str: name_bytes = name.encode("utf-8") ptr_err = lib.SetDisconnected(name_bytes) err_string = GetPtrString(ptr_err) @@ -68,7 +116,7 @@ def register_callback(eduvpn): ) -def SetProfileID(name, profile_id) -> str: +def SetProfileID(name: str, profile_id: str) -> str: name_bytes = name.encode("utf-8") profile_bytes = profile_id.encode("utf-8") error_string = lib.SetProfileID(name_bytes, profile_bytes) @@ -76,68 +124,93 @@ def SetProfileID(name, profile_id) -> str: class EduVPN(object): - def __init__(self, name, config_directory): + def __init__(self, name: str, config_directory: str): self.event_handler = EventHandler() self.name = name self.config_directory = config_directory register_callback(self) - def cancel_oauth(self) -> str: - return CancelOAuth(self.name) + def cancel_oauth(self) -> None: + cancel_oauth_err = CancelOAuth(self.name) - def deregister(self) -> str: - return Deregister(self.name) + if cancel_oauth_err: + raise Exception(cancel_oauth_err) - def register(self, debug=False) -> bool: - return Register(self.name, self.config_directory, callback_function, debug) == "" + def deregister(self) -> None: + deregister_err = Deregister(self.name) - def get_disco(self): - return GetDiscoServers(self.name) + if deregister_err: + raise Exception(deregister_err) - def get_config_institute_access(self, url, force_tcp=False): - return GetConnectConfig(self.name, url, False, force_tcp) + def register(self, debug: bool = False) -> None: + register_err = Register( + self.name, self.config_directory, callback_function, debug + ) - def get_config_secure_internet(self, url, force_tcp=False): - return GetConnectConfig(self.name, url, True, force_tcp) + if register_err: + raise Exception(register_err) - def set_disconnected(self): - return SetDisconnected(self.name) + def get_disco_servers(self) -> str: + servers, servers_err = GetDiscoServers(self.name) - def set_connected(self): - return SetConnected(self.name) + if servers_err: + raise Exception(servers_err) - @property - def event(self): - return self.event_handler + return servers - def callback(self, old_state, new_state, data): - self.event.run(old_state, new_state, data) + def get_disco_organizations(self) -> str: + organizations, organizations_err = GetDiscoOrganizations(self.name) - def set_profile(self, profile_id) -> str: - return SetProfileID(self.name, profile_id) + if organizations_err: + raise Exception(organizations_err) + return organizations -class EventHandler(object): - def __init__(self): - self.handlers = {} + def get_config_institute_access( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + config, config_type, config_err = GetConnectConfig( + self.name, url, False, force_tcp + ) - def on(self, state, state_type): - def wrapped_f(func): - if (state, state_type) not in self.handlers: - self.handlers[(state, state_type)] = [] - self.handlers[(state, state_type)].append(func) - return func + if config_err: + raise Exception(config_err) - return wrapped_f + return config, config_type - def run_state(self, state, state_type, data): - if (state, state_type) not in self.handlers: - return - for func in self.handlers[(state, state_type)]: - func(data) + def get_config_secure_internet( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + config, config_type, config_err = GetConnectConfig( + self.name, url, True, force_tcp + ) - def run(self, old_state, new_state, data): - if old_state == new_state: - return - self.run_state(old_state, StateType.Leave, data) - self.run_state(new_state, StateType.Enter, data) + if config_err: + raise Exception(config_err) + + return config, config_type + + def set_disconnected(self) -> None: + disconnect_err = SetDisconnected(self.name) + + if disconnect_err: + raise Exception(disconnect_err) + + def set_connected(self) -> None: + connect_err = SetConnected(self.name) + + if connect_err: + raise Exception(connect_err) + + @property + def event(self) -> EventHandler: + return self.event_handler + + def callback(self, old_state: str, new_state: str, data: str) -> None: + self.event.run(old_state, new_state, data) + + def set_profile(self, profile_id: str) -> None: + profile_err = SetProfileID(self.name, profile_id) + + if profile_err: + raise Exception(profile_err) diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py index f006646..60ed79e 100644 --- a/wrappers/python/tests.py +++ b/wrappers/python/tests.py @@ -13,20 +13,33 @@ from selenium_eduvpn import login_eduvpn class ConfigTests(unittest.TestCase): def testConfig(self): - self._eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "testconfigs") - assert self._eduvpn.register() - @self._eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter) - def oauth_initialized(url): + _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "testconfigs") + # This can throw an exception + _eduvpn.register() + @_eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter) + def oauth_initialized(old_state, url): login_eduvpn(url) server_uri = os.getenv("SERVER_URI") if not server_uri: self.fail("No SERVER_URI environment variable given") - config, error = self._eduvpn.get_config_institute_access(server_uri) - - if error != "": - self.fail(f"Got error: {error} when connecting to {server_uri}") + # This can throw an exception + _eduvpn.get_config_institute_access(server_uri) + + # Deregister + _eduvpn.deregister() + + def testDoubleRegister(self): + _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "testconfigs") + # This can throw an exception + _eduvpn.register() + # This should throw + try: + _eduvpn.register() + except Exception as e: + return + self.fail("No exception thrown on second register") if __name__ == "__main__": unittest.main() |
