From 0a19c2dedcaaa177b420eac99149515d84508204 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 26 Sep 2022 16:47:35 +0200 Subject: Python: Move from src/ to eduvpn_common/ and absolufy imports --- wrappers/python/eduvpn_common/__init__.py | 270 +++++++++++++++++++++++++++++ wrappers/python/eduvpn_common/discovery.py | 126 ++++++++++++++ wrappers/python/eduvpn_common/error.py | 15 ++ wrappers/python/eduvpn_common/event.py | 109 ++++++++++++ wrappers/python/eduvpn_common/main.py | 253 +++++++++++++++++++++++++++ wrappers/python/eduvpn_common/server.py | 175 +++++++++++++++++++ wrappers/python/eduvpn_common/state.py | 24 +++ wrappers/python/setup.py | 4 +- wrappers/python/src/__init__.py | 270 ----------------------------- wrappers/python/src/discovery.py | 126 -------------- wrappers/python/src/error.py | 15 -- wrappers/python/src/event.py | 109 ------------ wrappers/python/src/main.py | 253 --------------------------- wrappers/python/src/server.py | 175 ------------------- wrappers/python/src/state.py | 24 --- 15 files changed, 974 insertions(+), 974 deletions(-) create mode 100644 wrappers/python/eduvpn_common/__init__.py create mode 100644 wrappers/python/eduvpn_common/discovery.py create mode 100644 wrappers/python/eduvpn_common/error.py create mode 100644 wrappers/python/eduvpn_common/event.py create mode 100644 wrappers/python/eduvpn_common/main.py create mode 100644 wrappers/python/eduvpn_common/server.py create mode 100644 wrappers/python/eduvpn_common/state.py delete mode 100644 wrappers/python/src/__init__.py delete mode 100644 wrappers/python/src/discovery.py delete mode 100644 wrappers/python/src/error.py delete mode 100644 wrappers/python/src/event.py delete mode 100644 wrappers/python/src/main.py delete mode 100644 wrappers/python/src/server.py delete mode 100644 wrappers/python/src/state.py (limited to 'wrappers/python') diff --git a/wrappers/python/eduvpn_common/__init__.py b/wrappers/python/eduvpn_common/__init__.py new file mode 100644 index 0000000..1406fa2 --- /dev/null +++ b/wrappers/python/eduvpn_common/__init__.py @@ -0,0 +1,270 @@ +from ctypes import * +from collections import defaultdict +import pathlib +import platform +from typing import Tuple, Optional +from typing import List +from eduvpn_common.error import WrappedError, ErrorLevel + +_lib_prefixes = defaultdict( + lambda: "lib", + { + "windows": "", + }, +) + +_lib_suffixes = defaultdict( + lambda: ".so", + { + "windows": ".dll", + "darwin": ".dylib", + }, +) + +_os = platform.system().lower() + +_libname = "eduvpn_common" +_libfile = f"{_lib_prefixes[_os]}{_libname}{_lib_suffixes[_os]}" + +lib = None + +# Try to load in the normal path +try: + lib = cdll.LoadLibrary(_libfile) +# Otherwise, library should have been copied to the lib/ folder +except: + lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile)) + + +class cError(Structure): + _fields_ = [ + ("level", c_int), + ("traceback", c_char_p), + ("cause", c_char_p), + ] + +class cServerLocations(Structure): + _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] + + +class cDiscoveryOrganization(Structure): + _fields_ = [ + ("display_name", c_char_p), + ("org_id", c_char_p), + ("secure_internet_home", c_char_p), + ("keyword_list", c_char_p), + ] + + +class cDiscoveryOrganizations(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("organizations", POINTER(POINTER(cDiscoveryOrganization))), + ("total_organizations", c_size_t), + ] + + +class cDiscoveryServer(Structure): + _fields_ = [ + ("authentication_url_template", c_char_p), + ("base_url", c_char_p), + ("country_code", c_char_p), + ("display_name", c_char_p), + ("keyword_list", c_char_p), + ("public_key_list", POINTER(c_char_p)), + ("total_public_keys", c_size_t), + ("server_type", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ] + + +class cDiscoveryServers(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("servers", POINTER(POINTER(cDiscoveryServer))), + ("total_servers", c_size_t), + ] + + +class cServerProfile(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("default_gateway", c_int), + ] + + +class cServerProfiles(Structure): + _fields_ = [ + ("current", c_int), + ("profiles", POINTER(POINTER(cServerProfile))), + ("total_profiles", c_size_t), + ] + + +class cServer(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("server_type", c_char_p), + ("country_code", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ("profiles", POINTER(cServerProfiles)), + ("expire_time", c_ulonglong), + ] + + +class cServers(Structure): + _fields_ = [ + ("custom_servers", POINTER(POINTER(cServer))), + ("total_custom", c_size_t), + ("institute_servers", POINTER(POINTER(cServer))), + ("total_institute", c_size_t), + ("secure_internet", POINTER(cServer)), + ] + + +class DataError(Structure): + _fields_ = [("data", c_void_p), ("error", c_void_p)] + + +class ConfigError(Structure): + _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)] + + +VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) + +# Exposed functions +# We have to use c_void_p instead of c_char_p to free it properly +# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error +lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [ + c_char_p +], c_void_p +lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ + c_char_p, + c_char_p, +], c_void_p +lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ + c_char_p, + c_char_p, +], c_void_p +lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ + c_char_p, + c_char_p, + c_int, +], ConfigError +lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ + c_char_p, + c_char_p, + c_int, +], ConfigError +lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ + c_char_p, + c_char_p, + c_int, +], ConfigError +lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None +lib.Register.argtypes, lib.Register.restype = [ + c_char_p, + c_char_p, + VPNStateChange, + c_int, +], c_void_p +lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ + c_char_p +], DataError +lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError +lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None +lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p +lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p +lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [ + c_char_p +], c_void_p +lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [ + c_char_p, + c_char_p, +], c_void_p +lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p +lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p +lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p +lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c_void_p +lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p +lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int +lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p +lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None +lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None +lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None +lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ + c_void_p +], None +lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None +lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None +lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None +lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None +lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int +lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError + + +def encode_args(args, types): + for arg, t in zip(args, types): + # c_char_p needs the str to be encoded to bytes + if t is c_char_p: + arg = arg.encode("utf-8") + yield arg + + +def decode_res(t): + return decode_map.get(t, lambda x: x) + + +def get_ptr_string(ptr: c_void_p) -> str: + if ptr: + string = cast(ptr, c_char_p).value + lib.FreeString(ptr) + if string: + return string.decode("utf-8") + return "" + + +def get_ptr_list_strings( + strings: POINTER(c_char_p), total_strings: c_size_t +) -> List[str]: + if strings: + strings_list = [] + for i in range(total_strings): + strings_list.append(strings[i].decode("utf-8")) + return strings_list + return [] + +def get_error(ptr: c_void_p) -> Optional[WrappedError]: + if not ptr: + return None + err = cast(ptr, POINTER(cError)).contents + wrapped = WrappedError(err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level)) + lib.FreeError(ptr) + return wrapped + +def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]: + config = get_ptr_string(config_error.config) + config_type = get_ptr_string(config_error.config_type) + err = get_error(config_error.error) + return config, config_type, err + +def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]: + data = data_conv(data_error.data) + error = get_error(data_error.error) + return data, error + + +def get_bool(boolInt: c_int) -> bool: + return boolInt == 1 + + +decode_map = { + c_int: get_bool, + c_void_p: get_error, + DataError: get_data_error, + ConfigError: get_config_error, +} diff --git a/wrappers/python/eduvpn_common/discovery.py b/wrappers/python/eduvpn_common/discovery.py new file mode 100644 index 0000000..68741bc --- /dev/null +++ b/wrappers/python/eduvpn_common/discovery.py @@ -0,0 +1,126 @@ +from eduvpn_common import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings +from ctypes import cast, POINTER + + +class DiscoOrganization: + def __init__(self, display_name, org_id, secure_internet_home, keyword_list): + self.display_name = display_name + self.org_id = org_id + self.secure_internet_home = secure_internet_home + self.keyword_list = keyword_list + + def __str__(self): + return self.display_name + + +class DiscoOrganizations: + def __init__(self, version, organizations): + self.version = version + self.organizations = organizations + + +class DiscoServer: + def __init__( + self, + authentication_url_template, + base_url, + country_code, + display_name, + keyword_list, + public_keys, + server_type, + support_contacts, + ): + self.authentication_url_template = authentication_url_template + self.base_url = base_url + self.country_code = country_code + self.display_name = display_name + self.keyword_list = keyword_list + self.public_keys = public_keys + self.server_type = server_type + self.support_contacts = support_contacts + + def __str__(self): + return self.display_name + + +class DiscoServers: + def __init__(self, version, servers): + self.version = version + self.servers = servers + + +def get_disco_organization(ptr): + if not ptr: + return None + + current_organization = ptr.contents + display_name = current_organization.display_name.decode("utf-8") + org_id = current_organization.org_id.decode("utf-8") + secure_internet_home = current_organization.secure_internet_home.decode("utf-8") + keyword_list = current_organization.keyword_list.decode("utf-8") + return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) + + +def get_disco_server(ptr): + if not ptr: + return None + + current_server = ptr.contents + authentication_url_template = current_server.authentication_url_template.decode( + "utf-8" + ) + base_url = current_server.base_url.decode("utf-8") + country_code = current_server.country_code.decode("utf-8") + display_name = current_server.display_name.decode("utf-8") + keyword_list = current_server.keyword_list.decode("utf-8") + public_keys = get_ptr_list_strings( + current_server.public_key_list, current_server.total_public_keys + ) + server_type = current_server.server_type.decode("utf-8") + support_contacts = get_ptr_list_strings( + current_server.support_contact, current_server.total_support_contact + ) + return DiscoServer( + authentication_url_template, + base_url, + country_code, + display_name, + keyword_list, + public_keys, + server_type, + support_contacts, + ) + + +def get_disco_servers(ptr): + if ptr: + svrs = cast(ptr, POINTER(cDiscoveryServers)).contents + + servers = [] + + if svrs.servers: + for i in range(svrs.total_servers): + current = get_disco_server(svrs.servers[i]) + + if current is None: + continue + servers.append(current) + lib.FreeDiscoServers(ptr) + return DiscoServers(svrs.version, servers) + return None + + +def get_disco_organizations(ptr): + if ptr: + orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents + organizations = [] + if orgs.organizations: + for i in range(orgs.total_organizations): + current = get_disco_organization(orgs.organizations[i]) + if current is None: + continue + organizations.append(current) + lib.FreeDiscoOrganizations(ptr) + return DiscoOrganizations(orgs.version, organizations) + return None diff --git a/wrappers/python/eduvpn_common/error.py b/wrappers/python/eduvpn_common/error.py new file mode 100644 index 0000000..50298bb --- /dev/null +++ b/wrappers/python/eduvpn_common/error.py @@ -0,0 +1,15 @@ +from enum import Enum + +class ErrorLevel(Enum): + ERR_OTHER = 0 + ERR_INFO = 1 + ERR_WARNING = 2 + ERR_FATAL = 3 + +class WrappedError(Exception): + def __init__(self, traceback: str, cause: str, level: ErrorLevel): + super(WrappedError, self).__init__(cause) + self.traceback = traceback + self.cause = cause + self.level = level + diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py new file mode 100644 index 0000000..4532bef --- /dev/null +++ b/wrappers/python/eduvpn_common/event.py @@ -0,0 +1,109 @@ +from eduvpn_common import VPNStateChange, get_ptr_string +from enum import Enum +from typing import Callable +from eduvpn_common.state import State, StateType +from eduvpn_common.server import ( + get_locations, + get_transition_profiles, + get_transition_server, + get_servers, +) + + +EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" + +# A state transition decorator for classes +# To use this, make sure to register the class with `register_class_callbacks` +def class_state_transition(state: int, state_type: StateType) -> Callable: + def wrapper(func): + setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type)) + return func + + return wrapper + + +def convert_data(state: State, data): + if not data: + return None + if state is State.NO_SERVER: + return get_servers(data) + if state is State.OAUTH_STARTED: + return get_ptr_string(data) + if state is State.ASK_LOCATION: + return get_locations(data) + if state is State.ASK_PROFILE: + return get_transition_profiles(data) + if state in [ + State.DISCONNECTED, + State.DISCONNECTING, + State.CONNECTING, + State.CONNECTED, + ]: + return get_transition_server(data) + + +class EventHandler(object): + def __init__(self): + self.handlers = {} + + def change_class_callbacks(self, cls, add=True) -> None: + # Loop over method names + for method_name in dir(cls): + try: + # Get the method + method = getattr(cls, method_name) + except: + # Unable to get a value, go to the next + continue + + # If it has a callback defined, add it to the events + method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None) + if method_value: + state, state_type = method_value + + if add: + self.add_event(state, state_type, method) + else: + self.remove_event(state, state_type, method) + + def remove_event(self, state: int, state_type: StateType, func: Callable): + for key, values in self.handlers.copy().items(): + if key == (state, state_type): + values.remove(func) + if not values: + del self.handlers[key] + else: + self.handlers[key] = values + + def add_event(self, state: int, state_type: StateType, func: Callable): + if (state, state_type) not in self.handlers: + self.handlers[(state, state_type)] = [] + self.handlers[(state, state_type)].append(func) + + # A decorator for standalone functions + def on(self, state: int, state_type: StateType) -> Callable: + def wrapped_f(func): + self.add_event(state, state_type, func) + return func + + return wrapped_f + + def run_state( + self, state: int, other_state: int, state_type: StateType, data: str + ) -> None: + if (state, state_type) not in self.handlers: + return + for func in self.handlers[(state, state_type)]: + func(other_state, data) + + def run( + self, old_state: int, new_state: int, data: str, convert: bool = True + ) -> None: + # First run leave transitions, then enter + # The state is done when the wait event finishes + converted = data + if convert: + converted = convert_data(new_state, data) + self.run_state(old_state, new_state, StateType.Leave, converted) + self.run_state(new_state, old_state, StateType.Enter, converted) + self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py new file mode 100644 index 0000000..3875ad9 --- /dev/null +++ b/wrappers/python/eduvpn_common/main.py @@ -0,0 +1,253 @@ +from eduvpn_common import lib, VPNStateChange, encode_args, decode_res, get_data_error +from typing import Optional, Tuple +import threading +from eduvpn_common.discovery import get_disco_organizations, get_disco_servers +from eduvpn_common.event import EventHandler +from eduvpn_common.state import State, StateType +from eduvpn_common.server import get_servers + +eduvpn_objects = {} + + +def add_as_global_object(eduvpn) -> bool: + global eduvpn_objects + if eduvpn.name not in eduvpn_objects: + eduvpn_objects[eduvpn.name] = eduvpn + return True + return False + + +def remove_as_global_object(eduvpn): + global eduvpn_objects + eduvpn_objects.pop(eduvpn.name, None) + + +@VPNStateChange +def state_callback(name, old_state, new_state, data): + name = name.decode() + if name not in eduvpn_objects: + return + eduvpn_objects[name].callback(State(old_state), State(new_state), data) + + +class EduVPN(object): + def __init__(self, name: str, config_directory: str): + self.event_handler = EventHandler() + self.name = name + self.config_directory = config_directory + + # Callbacks that need to wait for specific events + + # The ask profile callback needs to wait for the UI thread to select a profile + # This is stored in the profile_event + self.profile_event: Optional[threading.Event] = None + self.location_event: Optional[threading.Event] = None + + @self.event.on(State.ASK_PROFILE, StateType.Wait) + def wait_profile_event(old_state: int, profiles: str): + if self.profile_event: + self.profile_event.wait() + + @self.event.on(State.ASK_LOCATION, StateType.Wait) + def wait_location_event(old_state: int, locations: str): + if self.location_event: + self.location_event.wait() + + def go_function(self, func, *args): + # The functions all have at least one arg type which is the name of the client + args_gen = encode_args(list(args), func.argtypes[1:]) + res = func(self.name.encode("utf-8"), *(args_gen)) + return decode_res(func.restype)(res) + + def go_function_custom_decode(self, func, decode_func, *args): + # The functions all have at least one arg type which is the name of the client + args_gen = encode_args(list(args), func.argtypes[1:]) + res = func(self.name.encode("utf-8"), *(args_gen)) + return decode_func(res) + + def cancel_oauth(self) -> None: + cancel_oauth_err = self.go_function(lib.CancelOAuth) + + if cancel_oauth_err: + raise cancel_oauth_err + + def deregister(self) -> None: + self.go_function(lib.Deregister) + remove_as_global_object(self) + + def register(self, debug: bool = False) -> None: + if not add_as_global_object(self): + raise Exception("Already registered") + + register_err = self.go_function( + lib.Register, self.config_directory, state_callback, debug + ) + + if register_err: + raise register_err + + def get_disco_servers(self) -> str: + servers, servers_err = self.go_function_custom_decode( + lib.GetDiscoServers, decode_func=lambda x: get_data_error(x, get_disco_servers) + ) + + if servers_err: + raise servers_err + + return servers + + def get_disco_organizations(self) -> str: + organizations, organizations_err = self.go_function_custom_decode( + lib.GetDiscoOrganizations, decode_func=lambda x: get_data_error(x, get_disco_organizations) + ) + + if organizations_err: + raise organizations_err + + return organizations + + def remove_secure_internet(self): + remove_err = self.go_function(lib.RemoveSecureInternet) + + if remove_err: + raise remove_err + + def remove_institute_access(self, url: str): + remove_err = self.go_function(lib.RemoveInstituteAccess, url) + + if remove_err: + raise remove_err + + def remove_custom_server(self, url: str): + remove_err = self.go_function(lib.RemoveCustomServer, url) + + if remove_err: + raise remove_err + + def get_config(self, url: str, func: callable, force_tcp: bool = False): + # Because it could be the case that a profile callback is started, store a threading event + # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set + # The event is set in self.set_profile + self.profile_event = threading.Event() + + config, config_type, config_err = self.go_function(func, url, force_tcp) + + self.profile_event = None + self.location_event = None + + if config_err: + raise config_err + + return config, config_type + + def get_config_custom_server( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + return self.get_config(url, lib.GetConfigCustomServer, force_tcp) + + def get_config_institute_access( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp) + + def get_config_secure_internet( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + self.location_event = threading.Event() + return self.get_config(url, lib.GetConfigSecureInternet, force_tcp) + + def go_back(self) -> None: + # Ignore the error + self.go_function(lib.GoBack) + + def set_connected(self) -> None: + connect_err = self.go_function(lib.SetConnected) + + if connect_err: + raise connect_err + + def set_disconnecting(self) -> None: + disconnecting_err = self.go_function(lib.SetDisconnecting) + + if disconnecting_err: + raise disconnecting_err + + def set_connecting(self) -> None: + connecting_err = self.go_function(lib.SetConnecting) + + if connecting_err: + raise connecting_err + + def set_disconnected(self, cleanup=True) -> None: + disconnect_err = self.go_function(lib.SetDisconnected, cleanup) + + if disconnect_err: + raise disconnect_err + + def set_search_server(self) -> None: + search_err = self.go_function(lib.SetSearchServer) + + if search_err: + raise search_err + + def remove_class_callbacks(self, cls) -> None: + self.event_handler.change_class_callbacks(cls, add=False) + + def register_class_callbacks(self, cls) -> None: + self.event_handler.change_class_callbacks(cls) + + @property + def event(self) -> EventHandler: + return self.event_handler + + def callback(self, old_state: State, new_state: State, data) -> None: + self.event.run(old_state, new_state, data) + + def set_profile(self, profile_id: str) -> None: + # Set the profile id + profile_err = self.go_function(lib.SetProfileID, profile_id) + + # If there is a profile event, set it so that the wait callback finishes + # And so that the Go code can move to the next state + if self.profile_event: + self.profile_event.set() + + if profile_err: + raise profile_err + + def change_secure_location(self) -> None: + # Set the location by country code + self.location_event = threading.Event() + location_err = self.go_function(lib.ChangeSecureLocation) + + if location_err: + raise location_err + + def set_secure_location(self, country_code: str) -> None: + # Set the location by country code + location_err = self.go_function(lib.SetSecureLocation, country_code) + + # If there is a location event, set it so that the wait callback finishes + # And so that the Go code can move to the next state + if self.location_event: + self.location_event.set() + + if location_err: + raise location_err + + def renew_session(self) -> None: + renew_err = self.go_function(lib.RenewSession) + + if renew_err: + raise renew_err + + def should_renew_button(self) -> bool: + return self.go_function(lib.ShouldRenewButton) + + def in_fsm_state(self, state_id: State) -> bool: + return self.go_function(lib.InFSMState, state_id) + + def get_saved_servers(self) -> str: + return self.go_function_custom_decode( + lib.GetSavedServers, decode_func=lambda x: get_data_error(x, get_servers) + ) diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py new file mode 100644 index 0000000..470f704 --- /dev/null +++ b/wrappers/python/eduvpn_common/server.py @@ -0,0 +1,175 @@ +from eduvpn_common import lib, cServer, cServers, cServerLocations, cServerProfiles +from ctypes import cast, POINTER, c_char_p +from datetime import datetime + + +class Profile: + def __init__(self, identifier, display_name, default_gateway: bool): + self.identifier = identifier + self.display_name = display_name + self.default_gateway = default_gateway + + def __str__(self): + return self.display_name + + +class Profiles: + def __init__(self, profiles, current): + self.profiles = profiles + self.current_index = current + + @property + def current(self): + if self.current_index < len(self.profiles): + return self.profiles[self.current_index] + return None + + +class Server: + def __init__(self, url, display_name, profiles=None, expire_time=0): + self.url = url + self.display_name = display_name + self.profiles = profiles + self.current_profile = None + self.expire_time = datetime.fromtimestamp(expire_time) + + def __str__(self): + return self.display_name + + +class InstituteServer(Server): + def __init__(self, url, display_name, support_contact, profiles, expire_time): + super().__init__(url, display_name, profiles, expire_time) + self.support_contact = support_contact + + +class SecureInternetServer(Server): + def __init__( + self, + org_id, + display_name, + support_contact, + profiles, + expire_time, + country_code, + ): + super().__init__(org_id, display_name, profiles, expire_time) + self.org_id = org_id + self.support_contact = support_contact + self.country_code = country_code + + +def get_type_for_str(type_str: str): + if type_str == "secure_internet": + return SecureInternetServer + if type_str == "custom_server": + return Server + return InstituteServer + + +def get_profiles(ptr): + if not ptr: + return [] + profiles = [] + _profiles = ptr.contents + current_profile = _profiles.current + if not _profiles.profiles: + return [] + for i in range(_profiles.total_profiles): + if not _profiles.profiles[i]: + continue + profile = _profiles.profiles[i].contents + profiles.append( + Profile( + profile.identifier.decode("utf-8"), + profile.display_name.decode("utf-8"), + profile.default_gateway == 1, + ) + ) + return Profiles(profiles, current_profile) + + +def get_server(ptr, _type=None): + if not ptr: + return None + + current_server = ptr.contents + if _type is None: + _type = get_type_for_str(current_server.server_type.decode("utf-8")) + + identifier = current_server.identifier.decode("utf-8") + display_name = current_server.display_name.decode("utf-8") + + if _type is not Server: + support_contact = [] + for i in range(current_server.total_support_contact): + support_contact.append(current_server.support_contact[i].decode("utf-8")) + profiles = get_profiles(current_server.profiles) + if _type is SecureInternetServer: + return SecureInternetServer( + identifier, + display_name, + support_contact, + profiles, + current_server.expire_time, + current_server.country_code.decode("utf-8"), + ) + if _type is InstituteServer: + return InstituteServer( + identifier, + display_name, + support_contact, + profiles, + current_server.expire_time, + ) + return Server(identifier, display_name, profiles, current_server.expire_time) + + +def get_transition_server(ptr): + server = get_server(cast(ptr, POINTER(cServer))) + lib.FreeServer(ptr) + return server + + +def get_transition_profiles(ptr): + profiles = get_profiles(cast(ptr, POINTER(cServerProfiles))) + lib.FreeProfiles(ptr) + return profiles + + +def get_servers(ptr): + if ptr: + returned = [] + servers = cast(ptr, POINTER(cServers)).contents + if servers.custom_servers: + for i in range(servers.total_custom): + current = get_server(servers.custom_servers[i], Server) + if current is None: + continue + returned.append(current) + + if servers.institute_servers: + for i in range(servers.total_institute): + current = get_server(servers.institute_servers[i], InstituteServer) + if current is None: + continue + returned.append(current) + + if servers.secure_internet: + current = get_server(servers.secure_internet, SecureInternetServer) + if current is not None: + returned.append(current) + lib.FreeServers(ptr) + return returned + return None + + +def get_locations(ptr): + if ptr: + locations = cast(ptr, POINTER(cServerLocations)).contents + location_list = [] + for i in range(locations.total_locations): + location_list.append(locations.locations[i].decode("utf-8")) + lib.FreeSecureLocations(ptr) + return location_list + return None diff --git a/wrappers/python/eduvpn_common/state.py b/wrappers/python/eduvpn_common/state.py new file mode 100644 index 0000000..5af004f --- /dev/null +++ b/wrappers/python/eduvpn_common/state.py @@ -0,0 +1,24 @@ +from enum import IntEnum + + +class StateType(IntEnum): + Enter = 1 + Leave = 2 + Wait = 3 + + +class State(IntEnum): + DEREGISTERED = 0 + NO_SERVER = 1 + ASK_LOCATION = 2 + SEARCH_SERVER = 3 + LOADING_SERVER = 4 + CHOSEN_SERVER = 5 + OAUTH_STARTED = 6 + AUTHORIZED = 7 + REQUEST_CONFIG = 8 + ASK_PROFILE = 9 + DISCONNECTED = 10 + DISCONNECTING = 11 + CONNECTING = 12 + CONNECTED = 13 diff --git a/wrappers/python/setup.py b/wrappers/python/setup.py index b2214e7..499627e 100755 --- a/wrappers/python/setup.py +++ b/wrappers/python/setup.py @@ -94,7 +94,7 @@ class bdist_wheel(_bdist_wheel): print(f"Building wheel for platform {self.plat_name}") # setuptools will only use paths inside the package for package_data, so we copy the library - tmp_lib = shutil.copy(f"{self.exports_lib_path}/{libpath}", "src/lib/") + tmp_lib = shutil.copy(f"{self.exports_lib_path}/{libpath}", "eduvpn_common/lib/") _bdist_wheel.run(self) os.remove(tmp_lib) @@ -104,7 +104,7 @@ setup( version="0.1.0", packages=["eduvpn_common"], python_requires=">=3.6", - package_dir={"eduvpn_common": "src"}, + package_dir={"eduvpn_common": "eduvpn_common"}, package_data={"eduvpn_common": [f"lib/*{_libname}*"]}, cmdclass={"bdist_wheel": bdist_wheel}, ) diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py deleted file mode 100644 index 5383259..0000000 --- a/wrappers/python/src/__init__.py +++ /dev/null @@ -1,270 +0,0 @@ -from ctypes import * -from collections import defaultdict -import pathlib -import platform -from typing import Tuple, Optional -from typing import List -from .error import WrappedError, ErrorLevel - -_lib_prefixes = defaultdict( - lambda: "lib", - { - "windows": "", - }, -) - -_lib_suffixes = defaultdict( - lambda: ".so", - { - "windows": ".dll", - "darwin": ".dylib", - }, -) - -_os = platform.system().lower() - -_libname = "eduvpn_common" -_libfile = f"{_lib_prefixes[_os]}{_libname}{_lib_suffixes[_os]}" - -lib = None - -# Try to load in the normal path -try: - lib = cdll.LoadLibrary(_libfile) -# Otherwise, library should have been copied to the lib/ folder -except: - lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile)) - - -class cError(Structure): - _fields_ = [ - ("level", c_int), - ("traceback", c_char_p), - ("cause", c_char_p), - ] - -class cServerLocations(Structure): - _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] - - -class cDiscoveryOrganization(Structure): - _fields_ = [ - ("display_name", c_char_p), - ("org_id", c_char_p), - ("secure_internet_home", c_char_p), - ("keyword_list", c_char_p), - ] - - -class cDiscoveryOrganizations(Structure): - _fields_ = [ - ("version", c_ulonglong), - ("organizations", POINTER(POINTER(cDiscoveryOrganization))), - ("total_organizations", c_size_t), - ] - - -class cDiscoveryServer(Structure): - _fields_ = [ - ("authentication_url_template", c_char_p), - ("base_url", c_char_p), - ("country_code", c_char_p), - ("display_name", c_char_p), - ("keyword_list", c_char_p), - ("public_key_list", POINTER(c_char_p)), - ("total_public_keys", c_size_t), - ("server_type", c_char_p), - ("support_contact", POINTER(c_char_p)), - ("total_support_contact", c_size_t), - ] - - -class cDiscoveryServers(Structure): - _fields_ = [ - ("version", c_ulonglong), - ("servers", POINTER(POINTER(cDiscoveryServer))), - ("total_servers", c_size_t), - ] - - -class cServerProfile(Structure): - _fields_ = [ - ("identifier", c_char_p), - ("display_name", c_char_p), - ("default_gateway", c_int), - ] - - -class cServerProfiles(Structure): - _fields_ = [ - ("current", c_int), - ("profiles", POINTER(POINTER(cServerProfile))), - ("total_profiles", c_size_t), - ] - - -class cServer(Structure): - _fields_ = [ - ("identifier", c_char_p), - ("display_name", c_char_p), - ("server_type", c_char_p), - ("country_code", c_char_p), - ("support_contact", POINTER(c_char_p)), - ("total_support_contact", c_size_t), - ("profiles", POINTER(cServerProfiles)), - ("expire_time", c_ulonglong), - ] - - -class cServers(Structure): - _fields_ = [ - ("custom_servers", POINTER(POINTER(cServer))), - ("total_custom", c_size_t), - ("institute_servers", POINTER(POINTER(cServer))), - ("total_institute", c_size_t), - ("secure_internet", POINTER(cServer)), - ] - - -class DataError(Structure): - _fields_ = [("data", c_void_p), ("error", c_void_p)] - - -class ConfigError(Structure): - _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)] - - -VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) - -# Exposed functions -# We have to use c_void_p instead of c_char_p to free it properly -# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error -lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [ - c_char_p -], c_void_p -lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ - c_char_p, - c_char_p, -], c_void_p -lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ - c_char_p, - c_char_p, -], c_void_p -lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ - c_char_p, - c_char_p, - c_int, -], ConfigError -lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ - c_char_p, - c_char_p, - c_int, -], ConfigError -lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ - c_char_p, - c_char_p, - c_int, -], ConfigError -lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None -lib.Register.argtypes, lib.Register.restype = [ - c_char_p, - c_char_p, - VPNStateChange, - c_int, -], c_void_p -lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ - c_char_p -], DataError -lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError -lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None -lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p -lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p -lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [ - c_char_p -], c_void_p -lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [ - c_char_p, - c_char_p, -], c_void_p -lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p -lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p -lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p -lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c_void_p -lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p -lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int -lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p -lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None -lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None -lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None -lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ - c_void_p -], None -lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None -lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None -lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None -lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None -lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int -lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError - - -def encode_args(args, types): - for arg, t in zip(args, types): - # c_char_p needs the str to be encoded to bytes - if t is c_char_p: - arg = arg.encode("utf-8") - yield arg - - -def decode_res(t): - return decode_map.get(t, lambda x: x) - - -def get_ptr_string(ptr: c_void_p) -> str: - if ptr: - string = cast(ptr, c_char_p).value - lib.FreeString(ptr) - if string: - return string.decode("utf-8") - return "" - - -def get_ptr_list_strings( - strings: POINTER(c_char_p), total_strings: c_size_t -) -> List[str]: - if strings: - strings_list = [] - for i in range(total_strings): - strings_list.append(strings[i].decode("utf-8")) - return strings_list - return [] - -def get_error(ptr: c_void_p) -> Optional[WrappedError]: - if not ptr: - return None - err = cast(ptr, POINTER(cError)).contents - wrapped = WrappedError(err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level)) - lib.FreeError(ptr) - return wrapped - -def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]: - config = get_ptr_string(config_error.config) - config_type = get_ptr_string(config_error.config_type) - err = get_error(config_error.error) - return config, config_type, err - -def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]: - data = data_conv(data_error.data) - error = get_error(data_error.error) - return data, error - - -def get_bool(boolInt: c_int) -> bool: - return boolInt == 1 - - -decode_map = { - c_int: get_bool, - c_void_p: get_error, - DataError: get_data_error, - ConfigError: get_config_error, -} diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py deleted file mode 100644 index a1b87ea..0000000 --- a/wrappers/python/src/discovery.py +++ /dev/null @@ -1,126 +0,0 @@ -from . import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings -from ctypes import cast, POINTER - - -class DiscoOrganization: - def __init__(self, display_name, org_id, secure_internet_home, keyword_list): - self.display_name = display_name - self.org_id = org_id - self.secure_internet_home = secure_internet_home - self.keyword_list = keyword_list - - def __str__(self): - return self.display_name - - -class DiscoOrganizations: - def __init__(self, version, organizations): - self.version = version - self.organizations = organizations - - -class DiscoServer: - def __init__( - self, - authentication_url_template, - base_url, - country_code, - display_name, - keyword_list, - public_keys, - server_type, - support_contacts, - ): - self.authentication_url_template = authentication_url_template - self.base_url = base_url - self.country_code = country_code - self.display_name = display_name - self.keyword_list = keyword_list - self.public_keys = public_keys - self.server_type = server_type - self.support_contacts = support_contacts - - def __str__(self): - return self.display_name - - -class DiscoServers: - def __init__(self, version, servers): - self.version = version - self.servers = servers - - -def get_disco_organization(ptr): - if not ptr: - return None - - current_organization = ptr.contents - display_name = current_organization.display_name.decode("utf-8") - org_id = current_organization.org_id.decode("utf-8") - secure_internet_home = current_organization.secure_internet_home.decode("utf-8") - keyword_list = current_organization.keyword_list.decode("utf-8") - return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) - - -def get_disco_server(ptr): - if not ptr: - return None - - current_server = ptr.contents - authentication_url_template = current_server.authentication_url_template.decode( - "utf-8" - ) - base_url = current_server.base_url.decode("utf-8") - country_code = current_server.country_code.decode("utf-8") - display_name = current_server.display_name.decode("utf-8") - keyword_list = current_server.keyword_list.decode("utf-8") - public_keys = get_ptr_list_strings( - current_server.public_key_list, current_server.total_public_keys - ) - server_type = current_server.server_type.decode("utf-8") - support_contacts = get_ptr_list_strings( - current_server.support_contact, current_server.total_support_contact - ) - return DiscoServer( - authentication_url_template, - base_url, - country_code, - display_name, - keyword_list, - public_keys, - server_type, - support_contacts, - ) - - -def get_disco_servers(ptr): - if ptr: - svrs = cast(ptr, POINTER(cDiscoveryServers)).contents - - servers = [] - - if svrs.servers: - for i in range(svrs.total_servers): - current = get_disco_server(svrs.servers[i]) - - if current is None: - continue - servers.append(current) - lib.FreeDiscoServers(ptr) - return DiscoServers(svrs.version, servers) - return None - - -def get_disco_organizations(ptr): - if ptr: - orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents - organizations = [] - if orgs.organizations: - for i in range(orgs.total_organizations): - current = get_disco_organization(orgs.organizations[i]) - if current is None: - continue - organizations.append(current) - lib.FreeDiscoOrganizations(ptr) - return DiscoOrganizations(orgs.version, organizations) - return None diff --git a/wrappers/python/src/error.py b/wrappers/python/src/error.py deleted file mode 100644 index 50298bb..0000000 --- a/wrappers/python/src/error.py +++ /dev/null @@ -1,15 +0,0 @@ -from enum import Enum - -class ErrorLevel(Enum): - ERR_OTHER = 0 - ERR_INFO = 1 - ERR_WARNING = 2 - ERR_FATAL = 3 - -class WrappedError(Exception): - def __init__(self, traceback: str, cause: str, level: ErrorLevel): - super(WrappedError, self).__init__(cause) - self.traceback = traceback - self.cause = cause - self.level = level - diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py deleted file mode 100644 index cf1a9d0..0000000 --- a/wrappers/python/src/event.py +++ /dev/null @@ -1,109 +0,0 @@ -from . import VPNStateChange, get_ptr_string -from enum import Enum -from typing import Callable -from .state import State, StateType -from .server import ( - get_locations, - get_transition_profiles, - get_transition_server, - get_servers, -) - - -EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" - -# A state transition decorator for classes -# To use this, make sure to register the class with `register_class_callbacks` -def class_state_transition(state: int, state_type: StateType) -> Callable: - def wrapper(func): - setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type)) - return func - - return wrapper - - -def convert_data(state: State, data): - if not data: - return None - if state is State.NO_SERVER: - return get_servers(data) - if state is State.OAUTH_STARTED: - return get_ptr_string(data) - if state is State.ASK_LOCATION: - return get_locations(data) - if state is State.ASK_PROFILE: - return get_transition_profiles(data) - if state in [ - State.DISCONNECTED, - State.DISCONNECTING, - State.CONNECTING, - State.CONNECTED, - ]: - return get_transition_server(data) - - -class EventHandler(object): - def __init__(self): - self.handlers = {} - - def change_class_callbacks(self, cls, add=True) -> None: - # Loop over method names - for method_name in dir(cls): - try: - # Get the method - method = getattr(cls, method_name) - except: - # Unable to get a value, go to the next - continue - - # If it has a callback defined, add it to the events - method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None) - if method_value: - state, state_type = method_value - - if add: - self.add_event(state, state_type, method) - else: - self.remove_event(state, state_type, method) - - def remove_event(self, state: int, state_type: StateType, func: Callable): - for key, values in self.handlers.copy().items(): - if key == (state, state_type): - values.remove(func) - if not values: - del self.handlers[key] - else: - self.handlers[key] = values - - def add_event(self, state: int, state_type: StateType, func: Callable): - if (state, state_type) not in self.handlers: - self.handlers[(state, state_type)] = [] - self.handlers[(state, state_type)].append(func) - - # A decorator for standalone functions - def on(self, state: int, state_type: StateType) -> Callable: - def wrapped_f(func): - self.add_event(state, state_type, func) - return func - - return wrapped_f - - def run_state( - self, state: int, other_state: int, state_type: StateType, data: str - ) -> None: - if (state, state_type) not in self.handlers: - return - for func in self.handlers[(state, state_type)]: - func(other_state, data) - - def run( - self, old_state: int, new_state: int, data: str, convert: bool = True - ) -> None: - # First run leave transitions, then enter - # The state is done when the wait event finishes - converted = data - if convert: - converted = convert_data(new_state, data) - self.run_state(old_state, new_state, StateType.Leave, converted) - self.run_state(new_state, old_state, StateType.Enter, converted) - self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py deleted file mode 100644 index 01621ae..0000000 --- a/wrappers/python/src/main.py +++ /dev/null @@ -1,253 +0,0 @@ -from . import lib, VPNStateChange, encode_args, decode_res, get_data_error -from typing import Optional, Tuple -import threading -from .discovery import get_disco_organizations, get_disco_servers -from .event import EventHandler -from .state import State, StateType -from .server import get_servers - -eduvpn_objects = {} - - -def add_as_global_object(eduvpn) -> bool: - global eduvpn_objects - if eduvpn.name not in eduvpn_objects: - eduvpn_objects[eduvpn.name] = eduvpn - return True - return False - - -def remove_as_global_object(eduvpn): - global eduvpn_objects - eduvpn_objects.pop(eduvpn.name, None) - - -@VPNStateChange -def state_callback(name, old_state, new_state, data): - name = name.decode() - if name not in eduvpn_objects: - return - eduvpn_objects[name].callback(State(old_state), State(new_state), data) - - -class EduVPN(object): - def __init__(self, name: str, config_directory: str): - self.event_handler = EventHandler() - self.name = name - self.config_directory = config_directory - - # Callbacks that need to wait for specific events - - # The ask profile callback needs to wait for the UI thread to select a profile - # This is stored in the profile_event - self.profile_event: Optional[threading.Event] = None - self.location_event: Optional[threading.Event] = None - - @self.event.on(State.ASK_PROFILE, StateType.Wait) - def wait_profile_event(old_state: int, profiles: str): - if self.profile_event: - self.profile_event.wait() - - @self.event.on(State.ASK_LOCATION, StateType.Wait) - def wait_location_event(old_state: int, locations: str): - if self.location_event: - self.location_event.wait() - - def go_function(self, func, *args): - # The functions all have at least one arg type which is the name of the client - args_gen = encode_args(list(args), func.argtypes[1:]) - res = func(self.name.encode("utf-8"), *(args_gen)) - return decode_res(func.restype)(res) - - def go_function_custom_decode(self, func, decode_func, *args): - # The functions all have at least one arg type which is the name of the client - args_gen = encode_args(list(args), func.argtypes[1:]) - res = func(self.name.encode("utf-8"), *(args_gen)) - return decode_func(res) - - def cancel_oauth(self) -> None: - cancel_oauth_err = self.go_function(lib.CancelOAuth) - - if cancel_oauth_err: - raise cancel_oauth_err - - def deregister(self) -> None: - self.go_function(lib.Deregister) - remove_as_global_object(self) - - def register(self, debug: bool = False) -> None: - if not add_as_global_object(self): - raise Exception("Already registered") - - register_err = self.go_function( - lib.Register, self.config_directory, state_callback, debug - ) - - if register_err: - raise register_err - - def get_disco_servers(self) -> str: - servers, servers_err = self.go_function_custom_decode( - lib.GetDiscoServers, decode_func=lambda x: get_data_error(x, get_disco_servers) - ) - - if servers_err: - raise servers_err - - return servers - - def get_disco_organizations(self) -> str: - organizations, organizations_err = self.go_function_custom_decode( - lib.GetDiscoOrganizations, decode_func=lambda x: get_data_error(x, get_disco_organizations) - ) - - if organizations_err: - raise organizations_err - - return organizations - - def remove_secure_internet(self): - remove_err = self.go_function(lib.RemoveSecureInternet) - - if remove_err: - raise remove_err - - def remove_institute_access(self, url: str): - remove_err = self.go_function(lib.RemoveInstituteAccess, url) - - if remove_err: - raise remove_err - - def remove_custom_server(self, url: str): - remove_err = self.go_function(lib.RemoveCustomServer, url) - - if remove_err: - raise remove_err - - def get_config(self, url: str, func: callable, force_tcp: bool = False): - # Because it could be the case that a profile callback is started, store a threading event - # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set - # The event is set in self.set_profile - self.profile_event = threading.Event() - - config, config_type, config_err = self.go_function(func, url, force_tcp) - - self.profile_event = None - self.location_event = None - - if config_err: - raise config_err - - return config, config_type - - def get_config_custom_server( - self, url: str, force_tcp: bool = False - ) -> Tuple[str, str]: - return self.get_config(url, lib.GetConfigCustomServer, force_tcp) - - def get_config_institute_access( - self, url: str, force_tcp: bool = False - ) -> Tuple[str, str]: - return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp) - - def get_config_secure_internet( - self, url: str, force_tcp: bool = False - ) -> Tuple[str, str]: - self.location_event = threading.Event() - return self.get_config(url, lib.GetConfigSecureInternet, force_tcp) - - def go_back(self) -> None: - # Ignore the error - self.go_function(lib.GoBack) - - def set_connected(self) -> None: - connect_err = self.go_function(lib.SetConnected) - - if connect_err: - raise connect_err - - def set_disconnecting(self) -> None: - disconnecting_err = self.go_function(lib.SetDisconnecting) - - if disconnecting_err: - raise disconnecting_err - - def set_connecting(self) -> None: - connecting_err = self.go_function(lib.SetConnecting) - - if connecting_err: - raise connecting_err - - def set_disconnected(self, cleanup=True) -> None: - disconnect_err = self.go_function(lib.SetDisconnected, cleanup) - - if disconnect_err: - raise disconnect_err - - def set_search_server(self) -> None: - search_err = self.go_function(lib.SetSearchServer) - - if search_err: - raise search_err - - def remove_class_callbacks(self, cls) -> None: - self.event_handler.change_class_callbacks(cls, add=False) - - def register_class_callbacks(self, cls) -> None: - self.event_handler.change_class_callbacks(cls) - - @property - def event(self) -> EventHandler: - return self.event_handler - - def callback(self, old_state: State, new_state: State, data) -> None: - self.event.run(old_state, new_state, data) - - def set_profile(self, profile_id: str) -> None: - # Set the profile id - profile_err = self.go_function(lib.SetProfileID, profile_id) - - # If there is a profile event, set it so that the wait callback finishes - # And so that the Go code can move to the next state - if self.profile_event: - self.profile_event.set() - - if profile_err: - raise profile_err - - def change_secure_location(self) -> None: - # Set the location by country code - self.location_event = threading.Event() - location_err = self.go_function(lib.ChangeSecureLocation) - - if location_err: - raise location_err - - def set_secure_location(self, country_code: str) -> None: - # Set the location by country code - location_err = self.go_function(lib.SetSecureLocation, country_code) - - # If there is a location event, set it so that the wait callback finishes - # And so that the Go code can move to the next state - if self.location_event: - self.location_event.set() - - if location_err: - raise location_err - - def renew_session(self) -> None: - renew_err = self.go_function(lib.RenewSession) - - if renew_err: - raise renew_err - - def should_renew_button(self) -> bool: - return self.go_function(lib.ShouldRenewButton) - - def in_fsm_state(self, state_id: State) -> bool: - return self.go_function(lib.InFSMState, state_id) - - def get_saved_servers(self) -> str: - return self.go_function_custom_decode( - lib.GetSavedServers, decode_func=lambda x: get_data_error(x, get_servers) - ) diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py deleted file mode 100644 index dce5d51..0000000 --- a/wrappers/python/src/server.py +++ /dev/null @@ -1,175 +0,0 @@ -from . import lib, cServer, cServers, cServerLocations, cServerProfiles -from ctypes import cast, POINTER, c_char_p -from datetime import datetime - - -class Profile: - def __init__(self, identifier, display_name, default_gateway: bool): - self.identifier = identifier - self.display_name = display_name - self.default_gateway = default_gateway - - def __str__(self): - return self.display_name - - -class Profiles: - def __init__(self, profiles, current): - self.profiles = profiles - self.current_index = current - - @property - def current(self): - if self.current_index < len(self.profiles): - return self.profiles[self.current_index] - return None - - -class Server: - def __init__(self, url, display_name, profiles=None, expire_time=0): - self.url = url - self.display_name = display_name - self.profiles = profiles - self.current_profile = None - self.expire_time = datetime.fromtimestamp(expire_time) - - def __str__(self): - return self.display_name - - -class InstituteServer(Server): - def __init__(self, url, display_name, support_contact, profiles, expire_time): - super().__init__(url, display_name, profiles, expire_time) - self.support_contact = support_contact - - -class SecureInternetServer(Server): - def __init__( - self, - org_id, - display_name, - support_contact, - profiles, - expire_time, - country_code, - ): - super().__init__(org_id, display_name, profiles, expire_time) - self.org_id = org_id - self.support_contact = support_contact - self.country_code = country_code - - -def get_type_for_str(type_str: str): - if type_str == "secure_internet": - return SecureInternetServer - if type_str == "custom_server": - return Server - return InstituteServer - - -def get_profiles(ptr): - if not ptr: - return [] - profiles = [] - _profiles = ptr.contents - current_profile = _profiles.current - if not _profiles.profiles: - return [] - for i in range(_profiles.total_profiles): - if not _profiles.profiles[i]: - continue - profile = _profiles.profiles[i].contents - profiles.append( - Profile( - profile.identifier.decode("utf-8"), - profile.display_name.decode("utf-8"), - profile.default_gateway == 1, - ) - ) - return Profiles(profiles, current_profile) - - -def get_server(ptr, _type=None): - if not ptr: - return None - - current_server = ptr.contents - if _type is None: - _type = get_type_for_str(current_server.server_type.decode("utf-8")) - - identifier = current_server.identifier.decode("utf-8") - display_name = current_server.display_name.decode("utf-8") - - if _type is not Server: - support_contact = [] - for i in range(current_server.total_support_contact): - support_contact.append(current_server.support_contact[i].decode("utf-8")) - profiles = get_profiles(current_server.profiles) - if _type is SecureInternetServer: - return SecureInternetServer( - identifier, - display_name, - support_contact, - profiles, - current_server.expire_time, - current_server.country_code.decode("utf-8"), - ) - if _type is InstituteServer: - return InstituteServer( - identifier, - display_name, - support_contact, - profiles, - current_server.expire_time, - ) - return Server(identifier, display_name, profiles, current_server.expire_time) - - -def get_transition_server(ptr): - server = get_server(cast(ptr, POINTER(cServer))) - lib.FreeServer(ptr) - return server - - -def get_transition_profiles(ptr): - profiles = get_profiles(cast(ptr, POINTER(cServerProfiles))) - lib.FreeProfiles(ptr) - return profiles - - -def get_servers(ptr): - if ptr: - returned = [] - servers = cast(ptr, POINTER(cServers)).contents - if servers.custom_servers: - for i in range(servers.total_custom): - current = get_server(servers.custom_servers[i], Server) - if current is None: - continue - returned.append(current) - - if servers.institute_servers: - for i in range(servers.total_institute): - current = get_server(servers.institute_servers[i], InstituteServer) - if current is None: - continue - returned.append(current) - - if servers.secure_internet: - current = get_server(servers.secure_internet, SecureInternetServer) - if current is not None: - returned.append(current) - lib.FreeServers(ptr) - return returned - return None - - -def get_locations(ptr): - if ptr: - locations = cast(ptr, POINTER(cServerLocations)).contents - location_list = [] - for i in range(locations.total_locations): - location_list.append(locations.locations[i].decode("utf-8")) - lib.FreeSecureLocations(ptr) - return location_list - return None diff --git a/wrappers/python/src/state.py b/wrappers/python/src/state.py deleted file mode 100644 index 5af004f..0000000 --- a/wrappers/python/src/state.py +++ /dev/null @@ -1,24 +0,0 @@ -from enum import IntEnum - - -class StateType(IntEnum): - Enter = 1 - Leave = 2 - Wait = 3 - - -class State(IntEnum): - DEREGISTERED = 0 - NO_SERVER = 1 - ASK_LOCATION = 2 - SEARCH_SERVER = 3 - LOADING_SERVER = 4 - CHOSEN_SERVER = 5 - OAUTH_STARTED = 6 - AUTHORIZED = 7 - REQUEST_CONFIG = 8 - ASK_PROFILE = 9 - DISCONNECTED = 10 - DISCONNECTING = 11 - CONNECTING = 12 - CONNECTED = 13 -- cgit v1.2.3