diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-14 13:56:49 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-14 13:56:49 +0200 |
| commit | da83f54606c9c1d2786d87074ee17ed972d2e1b2 (patch) | |
| tree | 0be57934f9f467c87576abb0b457fb54b2d25d52 /wrappers/python | |
| parent | fd34e72da8c604517050ada7e883ba982829d985 (diff) | |
Refactor: Return without json
Diffstat (limited to 'wrappers/python')
| -rw-r--r-- | wrappers/python/main.py | 57 | ||||
| -rw-r--r-- | wrappers/python/src/__init__.py | 68 | ||||
| -rw-r--r-- | wrappers/python/src/discovery.py | 43 | ||||
| -rw-r--r-- | wrappers/python/src/event.py | 21 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 25 | ||||
| -rw-r--r-- | wrappers/python/src/server.py | 158 | ||||
| -rw-r--r-- | wrappers/python/tests.py | 2 |
7 files changed, 331 insertions, 43 deletions
diff --git a/wrappers/python/main.py b/wrappers/python/main.py index 1ab29cc..0bd2502 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -3,6 +3,8 @@ from eduvpn_common.state import State, StateType import webbrowser import json import sys +import time +from typing import List # Asks the user for a profile index # It loops up until a valid input is given @@ -27,6 +29,11 @@ def ask_profile_input(total: int) -> int: # Sets up the callbacks using the provided class def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None: # The callback that starst OAuth + @_eduvpn.event.on(State.NO_SERVER, StateType.Enter) + def no_server(old_state: str, servers) -> None: + for server in servers: + print(type(server)) + print(server) # It needs to open the URL in the web browser @_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter) def oauth_initialized(old_state: str, url: str) -> None: @@ -34,31 +41,30 @@ def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None: webbrowser.open(url) @_eduvpn.event.on(State.ASK_LOCATION, StateType.Enter) - def ask_location(old_state: str, locations: str): - print("Locations: ", locations) - _eduvpn.set_secure_location("NL") + def ask_location(old_state: str, locations: List[str]): + _eduvpn.set_secure_location(locations[1]) - # The callback which asks the user for a profile - @_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter) - def ask_profile(old_state: str, profiles: str): - print("Multiple profiles found, you need to select a profile:") + ## The callback which asks the user for a profile + #@_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter) + #def ask_profile(old_state: str, profiles: str): + # print("Multiple profiles found, you need to select a profile:") - # Parse the profiles as JSON - data = json.loads(profiles) + # # Parse the profiles as JSON + # data = json.loads(profiles) - # Get a lits of profiles - profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]] - total_profiles = len(profile_strings) + # # Get a lits of profiles + # profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]] + # total_profiles = len(profile_strings) - # Create a list of the strings to standard output - for idx, profile in enumerate(profile_strings): - print(f"{idx+1}. {profile}") + # # Create a list of the strings to standard output + # for idx, profile in enumerate(profile_strings): + # print(f"{idx+1}. {profile}") - # Get the profile index from the user - profile_index = ask_profile_input(total_profiles) + # # Get the profile index from the user + # profile_index = ask_profile_input(total_profiles) - # Set the profile with the index - _eduvpn.set_profile(profile_strings[profile_index]) + # # Set the profile with the index + # _eduvpn.set_profile(profile_strings[profile_index]) # The main entry point @@ -72,18 +78,13 @@ if __name__ == "__main__": except Exception as e: print("Failed registering:", e) - server = input( - "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): " - ) - - # Ensure we have a valid http prefix - if not server.startswith("http"): - # https by default - server = "https://" + server + #server = input( + # "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): " + #) # Get a Wireguard/OpenVPN config try: - config, config_type = _eduvpn.get_config_custom_server(server) + config, config_type = _eduvpn.get_config_secure_internet("https://idp.geant.org") print(f"Got a config with type: {config_type} and contents:\n{config}") except Exception as e: print("Failed to connect:", e) diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index 5b63651..db5484f 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -40,12 +40,66 @@ class ErrorLevel(Enum): ERR_OTHER = 0 ERR_INFO = 1 +class cServerLocations(Structure): + _fields_ = [ + ("locations", POINTER(c_char_p)), + ("total_locations", c_size_t) + ] + +class cDiscoveryOrganization(Structure): + _fields_ = [ + ("display_name", c_char_p), + ("org_id", c_char_p), + ("secure_internet_home", c_char_p), + ("keyword_list", c_char_p), + ] + +class cDiscoveryOrganizations(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("organizations", POINTER(POINTER(cDiscoveryOrganization))), + ("total_organizations", c_size_t), + ] + +class cServerProfile(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("default_gateway", c_int), + ] + +class cServerProfiles(Structure): + _fields_ = [ + ("current", c_int), + ("profiles", POINTER(POINTER(cServerProfile))), + ("total_profiles", c_size_t), + ] + +class cServer(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("country_code", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ("profiles", POINTER(cServerProfiles)), + ("expire_time", c_ulonglong), + ] + +class cServers(Structure): + _fields_ = [ + ("custom_servers", POINTER(POINTER(cServer))), + ("total_custom", c_size_t), + ("institute_servers", POINTER(POINTER(cServer))), + ("total_institute", c_size_t), + ("secure_internet", POINTER(cServer)), + ] class DataError(Structure): _fields_ = [("data", c_void_p), ("error", c_void_p)] -VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_char_p) +VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly @@ -77,7 +131,7 @@ lib.Register.argtypes, lib.Register.restype = [ ], c_void_p lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ c_char_p -], DataError +], c_void_p lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p @@ -96,8 +150,12 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p +lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None +lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None +lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int +lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p class WrappedError: @@ -139,6 +197,8 @@ def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]: if not error_json: return None + if "level" not in error_json: + return error_string level = error_json["level"] traceback = error_json["traceback"] cause = error_json["cause"] @@ -149,6 +209,9 @@ def get_error(ptr: c_void_p) -> str: error = get_ptr_error(ptr) if not error: return "" + + if not isinstance(error, WrappedError): + return error return error.cause @@ -161,7 +224,6 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]: def get_bool(boolInt: c_int) -> bool: return boolInt == 1 - decode_map = { c_int: get_bool, c_void_p: get_error, diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py new file mode 100644 index 0000000..80c08cf --- /dev/null +++ b/wrappers/python/src/discovery.py @@ -0,0 +1,43 @@ +from . import lib, cDiscoveryOrganizations +from ctypes import cast, POINTER + + +class DiscoOrganization: + def __init__(self, display_name, org_id, secure_internet_home, keyword_list): + self.display_name = display_name + self.org_id = org_id + self.secure_internet_home = secure_internet_home + self.keyword_list = keyword_list + + +class DiscoOrganizations: + def __init__(self, version, organizations): + self.version = version + self.organizations = organizations + + +def get_disco_organization(ptr): + if not ptr: + return None + + current_organization = ptr.contents + display_name = current_organization.display_name.decode("utf-8") + org_id = current_organization.org_id.decode("utf-8") + secure_internet_home = current_organization.secure_internet_home.decode("utf-8") + keyword_list = current_organization.keyword_list.decode("utf-8") + return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) + + +def get_disco_organizations(ptr): + if ptr: + orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents + organizations = [] + if orgs.organizations: + for i in range(orgs.total_organizations): + current = get_disco_organization(orgs.organizations[i]) + if current is None: + continue + organizations.append(current) + lib.FreeDiscoOrganizations(ptr) + return DiscoOrganizations(orgs.version, organizations) + return None diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py index d0740f8..0e0f5ae 100644 --- a/wrappers/python/src/event.py +++ b/wrappers/python/src/event.py @@ -1,7 +1,8 @@ -from . import VPNStateChange +from . import VPNStateChange, get_ptr_string from enum import Enum from typing import Callable -from .state import StateType +from .state import State, StateType +from .server import get_locations, get_servers EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" @@ -15,6 +16,15 @@ def class_state_transition(state: int, state_type: StateType) -> Callable: return wrapper +def convert_data(state: State, data): + if not data: + return None + if state is State.NO_SERVER: + return get_servers(data) + if state is State.OAUTH_STARTED: + return get_ptr_string(data) + if state is State.ASK_LOCATION: + return get_locations(data) class EventHandler(object): def __init__(self): @@ -73,6 +83,7 @@ class EventHandler(object): def run(self, old_state: int, new_state: int, data: str) -> None: # First run leave transitions, then enter # The state is done when the wait event finishes - self.run_state(old_state, new_state, StateType.Leave, data) - self.run_state(new_state, old_state, StateType.Enter, data) - self.run_state(new_state, old_state, StateType.Wait, data) + converted = convert_data(new_state, data) + self.run_state(old_state, new_state, StateType.Leave, converted) + self.run_state(new_state, old_state, StateType.Enter, converted) + self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index b37842f..cbeadb5 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,8 +1,10 @@ from . import lib, VPNStateChange, encode_args, decode_res from typing import Optional, Tuple import threading +from .discovery import get_disco_organizations from .event import EventHandler from .state import State, StateType +from .server import get_servers import json eduvpn_objects = {} @@ -26,7 +28,7 @@ def state_callback(name, old_state, new_state, data): name = name.decode() if name not in eduvpn_objects: return - eduvpn_objects[name].callback(State(old_state), State(new_state), data.decode()) + eduvpn_objects[name].callback(State(old_state), State(new_state), data) class EduVPN(object): @@ -58,6 +60,12 @@ class EduVPN(object): res = func(self.name.encode("utf-8"), *(args_gen)) return decode_res(func.restype)(res) + def go_function_custom_decode(self, func, decode_func, *args): + # The functions all have at least one arg type which is the name of the client + args_gen = encode_args(list(args), func.argtypes[1:]) + res = func(self.name.encode("utf-8"), *(args_gen)) + return decode_func(res) + def cancel_oauth(self) -> None: cancel_oauth_err = self.go_function(lib.CancelOAuth) @@ -91,10 +99,9 @@ class EduVPN(object): return servers def get_disco_organizations(self) -> str: - organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations) - - if organizations_err: - raise Exception(organizations_err) + organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations) + #if organizations_err: + # raise Exception(organizations_err) return organizations @@ -196,7 +203,7 @@ class EduVPN(object): def event(self) -> EventHandler: return self.event_handler - def callback(self, old_state: State, new_state: State, data: str) -> None: + def callback(self, old_state: State, new_state: State, data) -> None: self.event.run(old_state, new_state, data) def set_profile(self, profile_id: str) -> None: @@ -242,3 +249,9 @@ class EduVPN(object): def in_fsm_state(self, state_id: State) -> bool: return self.go_function(lib.InFSMState, state_id) + + def get_saved_servers_old(self) -> str: + return self.go_function(lib.GetSavedServersOLD) + + def get_saved_servers_new(self) -> str: + return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers) diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py new file mode 100644 index 0000000..b765ede --- /dev/null +++ b/wrappers/python/src/server.py @@ -0,0 +1,158 @@ +from . import lib, cServers, cServerLocations +from ctypes import cast, POINTER + + +class Profile: + def __init__(self, identifier, display_name, default_gateway: bool): + self.identifier = identifier + self.display_name = display_name + self.default_gateway = default_gateway + + def __str__(self): + return f"Profile: {self.display_name}" + + +class Server: + def __init__(self, url, display_name, profiles, current_profile, expire_time): + self.url = url + self.display_name = display_name + self.profiles = profiles + self.current_profile = None + if current_profile < len(profiles): + self.current_profile = profiles[current_profile] + self.expire_time = expire_time + + def __str__(self): + return f"Server: {self.url}, with current profile: {self.current_profile}" + + +class InstituteServer(Server): + def __init__( + self, url, display_name, support_contact, profiles, current_profile, expire_time + ): + super().__init__(url, display_name, profiles, current_profile, expire_time) + self.support_contact = support_contact + + def __str__(self): + return f"Institute Server: {self.display_name}" + + +class SecureInternetServer(Server): + def __init__( + self, + url, + display_name, + support_contact, + profiles, + current_profile, + expire_time, + country_code, + ): + super().__init__(url, display_name, profiles, current_profile, expire_time) + self.support_contact = support_contact + self.country_code = country_code + + def __str__(self): + return f"Secure Internet Server: {self.display_name} with country {self.country_code}" + + +def get_type_for_str(type_str: str): + if type_str is "secure_internet": + return SecureInternetServer + if type_str is "custom_server": + return Server + return InstituteServer + + +def get_server(ptr, _type=None): + if not ptr: + return None + + current_server = ptr.contents + if _type is None: + _type = get_type_for_str(current_server.server_type.decode("utf-8")) + + identifier = current_server.identifier.decode("utf-8") + display_name = current_server.display_name.decode("utf-8") + + if _type is not Server: + support_contact = [] + for i in range(current_server.total_support_contact): + support_contact.append(current_server.support_contact[i].decode("utf-8")) + profiles = [] + if not current_server.profiles: + return None + + _profiles = current_server.profiles.contents + current_profile = _profiles.current + for i in range(_profiles.total_profiles): + if not _profiles.profiles or not _profiles.profiles[i]: + return None + profile = _profiles.profiles[i].contents + profiles.append( + Profile( + profile.identifier.decode("utf-8"), + profile.display_name.decode("utf-8"), + profile.default_gateway == 1, + ) + ) + + if _type is SecureInternetServer: + return SecureInternetServer( + identifier, + display_name, + support_contact, + profiles, + current_profile, + current_server.expire_time, + current_server.country_code.decode("utf-8"), + ) + if _type is InstituteServer: + return InstituteServer( + identifier, + display_name, + support_contact, + profiles, + current_profile, + current_server.expire_time, + ) + return Server( + identifier, display_name, profiles, current_profile, current_server.expire_time + ) + + +def get_servers(ptr): + if ptr: + returned = [] + servers = cast(ptr, POINTER(cServers)).contents + if servers.custom_servers: + for i in range(servers.total_custom): + current = get_server(servers.custom_servers[i], Server) + if current is None: + continue + returned.append(current) + + if servers.institute_servers: + for i in range(servers.total_institute): + current = get_server(servers.institute_servers[i], InstituteServer) + if current is None: + continue + returned.append(current) + + if servers.secure_internet: + current = get_server(servers.secure_internet, SecureInternetServer) + if current is not None: + returned.append(current) + lib.FreeServers(ptr) + return returned + return None + +def get_locations(ptr): + if ptr: + locations = cast(ptr, POINTER(cServerLocations)).contents + location_list = [] + for i in range(locations.total_locations): + location_list.append(locations.locations[i].decode("utf-8")) + lib.FreeSecureLocations(ptr) + return location_list + return None diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py index 60d3cce..679eda0 100644 --- a/wrappers/python/tests.py +++ b/wrappers/python/tests.py @@ -24,7 +24,7 @@ class ConfigTests(unittest.TestCase): @_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter) def oauth_initialized(old_state, url_json): - login_eduvpn(json.loads(url_json)) + login_eduvpn(url_json) server_uri = os.getenv("SERVER_URI") if not server_uri: |
