diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-26 17:36:30 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-27 10:53:37 +0200 |
| commit | 09ec69dfdef409868f1cb39cb8cc4b33c8690c9f (patch) | |
| tree | 109925dbbee4a9120211897582760f96010ae8f2 /wrappers/python/eduvpn_common | |
| parent | 0a19c2dedcaaa177b420eac99149515d84508204 (diff) | |
Python: Reformat and move most loading out of __init__
Diffstat (limited to 'wrappers/python/eduvpn_common')
| -rw-r--r-- | wrappers/python/eduvpn_common/__init__.py | 270 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/discovery.py | 18 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/error.py | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/event.py | 19 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 113 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 73 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/server.py | 12 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 175 |
8 files changed, 359 insertions, 324 deletions
diff --git a/wrappers/python/eduvpn_common/__init__.py b/wrappers/python/eduvpn_common/__init__.py index 1406fa2..e69de29 100644 --- a/wrappers/python/eduvpn_common/__init__.py +++ b/wrappers/python/eduvpn_common/__init__.py @@ -1,270 +0,0 @@ -from ctypes import * -from collections import defaultdict -import pathlib -import platform -from typing import Tuple, Optional -from typing import List -from eduvpn_common.error import WrappedError, ErrorLevel - -_lib_prefixes = defaultdict( - lambda: "lib", - { - "windows": "", - }, -) - -_lib_suffixes = defaultdict( - lambda: ".so", - { - "windows": ".dll", - "darwin": ".dylib", - }, -) - -_os = platform.system().lower() - -_libname = "eduvpn_common" -_libfile = f"{_lib_prefixes[_os]}{_libname}{_lib_suffixes[_os]}" - -lib = None - -# Try to load in the normal path -try: - lib = cdll.LoadLibrary(_libfile) -# Otherwise, library should have been copied to the lib/ folder -except: - lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile)) - - -class cError(Structure): - _fields_ = [ - ("level", c_int), - ("traceback", c_char_p), - ("cause", c_char_p), - ] - -class cServerLocations(Structure): - _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] - - -class cDiscoveryOrganization(Structure): - _fields_ = [ - ("display_name", c_char_p), - ("org_id", c_char_p), - ("secure_internet_home", c_char_p), - ("keyword_list", c_char_p), - ] - - -class cDiscoveryOrganizations(Structure): - _fields_ = [ - ("version", c_ulonglong), - ("organizations", POINTER(POINTER(cDiscoveryOrganization))), - ("total_organizations", c_size_t), - ] - - -class cDiscoveryServer(Structure): - _fields_ = [ - ("authentication_url_template", c_char_p), - ("base_url", c_char_p), - ("country_code", c_char_p), - ("display_name", c_char_p), - ("keyword_list", c_char_p), - ("public_key_list", POINTER(c_char_p)), - ("total_public_keys", c_size_t), - ("server_type", c_char_p), - ("support_contact", POINTER(c_char_p)), - ("total_support_contact", c_size_t), - ] - - -class cDiscoveryServers(Structure): - _fields_ = [ - ("version", c_ulonglong), - ("servers", POINTER(POINTER(cDiscoveryServer))), - ("total_servers", c_size_t), - ] - - -class cServerProfile(Structure): - _fields_ = [ - ("identifier", c_char_p), - ("display_name", c_char_p), - ("default_gateway", c_int), - ] - - -class cServerProfiles(Structure): - _fields_ = [ - ("current", c_int), - ("profiles", POINTER(POINTER(cServerProfile))), - ("total_profiles", c_size_t), - ] - - -class cServer(Structure): - _fields_ = [ - ("identifier", c_char_p), - ("display_name", c_char_p), - ("server_type", c_char_p), - ("country_code", c_char_p), - ("support_contact", POINTER(c_char_p)), - ("total_support_contact", c_size_t), - ("profiles", POINTER(cServerProfiles)), - ("expire_time", c_ulonglong), - ] - - -class cServers(Structure): - _fields_ = [ - ("custom_servers", POINTER(POINTER(cServer))), - ("total_custom", c_size_t), - ("institute_servers", POINTER(POINTER(cServer))), - ("total_institute", c_size_t), - ("secure_internet", POINTER(cServer)), - ] - - -class DataError(Structure): - _fields_ = [("data", c_void_p), ("error", c_void_p)] - - -class ConfigError(Structure): - _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)] - - -VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) - -# Exposed functions -# We have to use c_void_p instead of c_char_p to free it properly -# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error -lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [ - c_char_p -], c_void_p -lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ - c_char_p, - c_char_p, -], c_void_p -lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ - c_char_p, - c_char_p, -], c_void_p -lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ - c_char_p, - c_char_p, - c_int, -], ConfigError -lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ - c_char_p, - c_char_p, - c_int, -], ConfigError -lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ - c_char_p, - c_char_p, - c_int, -], ConfigError -lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None -lib.Register.argtypes, lib.Register.restype = [ - c_char_p, - c_char_p, - VPNStateChange, - c_int, -], c_void_p -lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ - c_char_p -], DataError -lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError -lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None -lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p -lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p -lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [ - c_char_p -], c_void_p -lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [ - c_char_p, - c_char_p, -], c_void_p -lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p -lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p -lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p -lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c_void_p -lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p -lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int -lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p -lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None -lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None -lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None -lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ - c_void_p -], None -lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None -lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None -lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None -lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None -lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int -lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError - - -def encode_args(args, types): - for arg, t in zip(args, types): - # c_char_p needs the str to be encoded to bytes - if t is c_char_p: - arg = arg.encode("utf-8") - yield arg - - -def decode_res(t): - return decode_map.get(t, lambda x: x) - - -def get_ptr_string(ptr: c_void_p) -> str: - if ptr: - string = cast(ptr, c_char_p).value - lib.FreeString(ptr) - if string: - return string.decode("utf-8") - return "" - - -def get_ptr_list_strings( - strings: POINTER(c_char_p), total_strings: c_size_t -) -> List[str]: - if strings: - strings_list = [] - for i in range(total_strings): - strings_list.append(strings[i].decode("utf-8")) - return strings_list - return [] - -def get_error(ptr: c_void_p) -> Optional[WrappedError]: - if not ptr: - return None - err = cast(ptr, POINTER(cError)).contents - wrapped = WrappedError(err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level)) - lib.FreeError(ptr) - return wrapped - -def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]: - config = get_ptr_string(config_error.config) - config_type = get_ptr_string(config_error.config_type) - err = get_error(config_error.error) - return config, config_type, err - -def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]: - data = data_conv(data_error.data) - error = get_error(data_error.error) - return data, error - - -def get_bool(boolInt: c_int) -> bool: - return boolInt == 1 - - -decode_map = { - c_int: get_bool, - c_void_p: get_error, - DataError: get_data_error, - ConfigError: get_config_error, -} diff --git a/wrappers/python/eduvpn_common/discovery.py b/wrappers/python/eduvpn_common/discovery.py index 68741bc..3a1cbe6 100644 --- a/wrappers/python/eduvpn_common/discovery.py +++ b/wrappers/python/eduvpn_common/discovery.py @@ -1,4 +1,8 @@ -from eduvpn_common import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings +from eduvpn_common.types import ( + cDiscoveryOrganizations, + cDiscoveryServers, + get_ptr_list_strings, +) from ctypes import cast, POINTER @@ -62,7 +66,7 @@ def get_disco_organization(ptr): return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) -def get_disco_server(ptr): +def get_disco_server(lib, ptr): if not ptr: return None @@ -75,11 +79,11 @@ def get_disco_server(ptr): display_name = current_server.display_name.decode("utf-8") keyword_list = current_server.keyword_list.decode("utf-8") public_keys = get_ptr_list_strings( - current_server.public_key_list, current_server.total_public_keys + lib, current_server.public_key_list, current_server.total_public_keys ) server_type = current_server.server_type.decode("utf-8") support_contacts = get_ptr_list_strings( - current_server.support_contact, current_server.total_support_contact + lib, current_server.support_contact, current_server.total_support_contact ) return DiscoServer( authentication_url_template, @@ -93,7 +97,7 @@ def get_disco_server(ptr): ) -def get_disco_servers(ptr): +def get_disco_servers(lib, ptr): if ptr: svrs = cast(ptr, POINTER(cDiscoveryServers)).contents @@ -101,7 +105,7 @@ def get_disco_servers(ptr): if svrs.servers: for i in range(svrs.total_servers): - current = get_disco_server(svrs.servers[i]) + current = get_disco_server(lib, svrs.servers[i]) if current is None: continue @@ -111,7 +115,7 @@ def get_disco_servers(ptr): return None -def get_disco_organizations(ptr): +def get_disco_organizations(lib, ptr): if ptr: orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents organizations = [] diff --git a/wrappers/python/eduvpn_common/error.py b/wrappers/python/eduvpn_common/error.py index 50298bb..a5b59b4 100644 --- a/wrappers/python/eduvpn_common/error.py +++ b/wrappers/python/eduvpn_common/error.py @@ -1,15 +1,16 @@ from enum import Enum + class ErrorLevel(Enum): ERR_OTHER = 0 ERR_INFO = 1 ERR_WARNING = 2 ERR_FATAL = 3 + class WrappedError(Exception): def __init__(self, traceback: str, cause: str, level: ErrorLevel): super(WrappedError, self).__init__(cause) self.traceback = traceback self.cause = cause self.level = level - diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py index 4532bef..a47a0a7 100644 --- a/wrappers/python/eduvpn_common/event.py +++ b/wrappers/python/eduvpn_common/event.py @@ -1,4 +1,3 @@ -from eduvpn_common import VPNStateChange, get_ptr_string from enum import Enum from typing import Callable from eduvpn_common.state import State, StateType @@ -8,6 +7,7 @@ from eduvpn_common.server import ( get_transition_server, get_servers, ) +from eduvpn_common.types import get_ptr_string EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" @@ -22,29 +22,30 @@ def class_state_transition(state: int, state_type: StateType) -> Callable: return wrapper -def convert_data(state: State, data): +def convert_data(lib, state: State, data): if not data: return None if state is State.NO_SERVER: - return get_servers(data) + return get_servers(lib, data) if state is State.OAUTH_STARTED: - return get_ptr_string(data) + return get_ptr_string(lib, data) if state is State.ASK_LOCATION: - return get_locations(data) + return get_locations(lib, data) if state is State.ASK_PROFILE: - return get_transition_profiles(data) + return get_transition_profiles(lib, data) if state in [ State.DISCONNECTED, State.DISCONNECTING, State.CONNECTING, State.CONNECTED, ]: - return get_transition_server(data) + return get_transition_server(lib, data) class EventHandler(object): - def __init__(self): + def __init__(self, lib): self.handlers = {} + self.lib = lib def change_class_callbacks(self, cls, add=True) -> None: # Loop over method names @@ -103,7 +104,7 @@ class EventHandler(object): # The state is done when the wait event finishes converted = data if convert: - converted = convert_data(new_state, data) + converted = convert_data(self.lib, new_state, data) self.run_state(old_state, new_state, StateType.Leave, converted) self.run_state(new_state, old_state, StateType.Enter, converted) self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py new file mode 100644 index 0000000..bce2638 --- /dev/null +++ b/wrappers/python/eduvpn_common/loader.py @@ -0,0 +1,113 @@ +from ctypes import * +from collections import defaultdict +import pathlib +import platform +from eduvpn_common.types import * + + +def load_lib(version: str): + lib_prefixes = defaultdict( + lambda: "lib", + { + "windows": "", + }, + ) + + lib_suffixes = defaultdict( + lambda: ".so", + { + "windows": ".dll", + "darwin": ".dylib", + }, + ) + + os = platform.system().lower() + + libname = "eduvpn_common" + libfile = f"{lib_prefixes[os]}{libname}{lib_suffixes[os]}" + + lib = None + + # Try to load in the normal path + try: + lib = cdll.LoadLibrary(libfile) + # Otherwise, library should have been copied to the lib/ folder + except: + lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / libfile)) + + return lib + + +def initialize_functions(lib): + # Exposed functions + # We have to use c_void_p instead of c_char_p to free it properly + # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error + lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p + lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [ + c_char_p + ], c_void_p + lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None + lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ + c_void_p + ], None + lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None + lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None + lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None + lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None + lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None + lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None + lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None + lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ + c_char_p, + c_char_p, + c_int, + ], ConfigError + lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ + c_char_p, + c_char_p, + c_int, + ], ConfigError + lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ + c_char_p, + c_char_p, + c_int, + ], ConfigError + lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ + c_char_p + ], DataError + lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError + lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError + lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None + lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int + lib.Register.argtypes, lib.Register.restype = [ + c_char_p, + c_char_p, + VPNStateChange, + c_int, + ], c_void_p + lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ + c_char_p, + c_char_p, + ], c_void_p + lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ + c_char_p, + c_char_p, + ], c_void_p + lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [ + c_char_p + ], c_void_p + lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p + lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p + lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p + lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [ + c_char_p, + c_int, + ], c_void_p + lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p + lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p + lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p + lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [ + c_char_p, + c_char_p, + ], c_void_p + lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 3875ad9..1b18fb0 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,13 +1,16 @@ -from eduvpn_common import lib, VPNStateChange, encode_args, decode_res, get_data_error from typing import Optional, Tuple import threading from eduvpn_common.discovery import get_disco_organizations, get_disco_servers from eduvpn_common.event import EventHandler -from eduvpn_common.state import State, StateType +from eduvpn_common.loader import initialize_functions, load_lib +from eduvpn_common.types import VPNStateChange, encode_args, decode_res, get_data_error from eduvpn_common.server import get_servers +from eduvpn_common.state import State, StateType eduvpn_objects = {} +VERSION = "0.1.0" + def add_as_global_object(eduvpn) -> bool: global eduvpn_objects @@ -32,10 +35,15 @@ def state_callback(name, old_state, new_state, data): class EduVPN(object): def __init__(self, name: str, config_directory: str): - self.event_handler = EventHandler() self.name = name self.config_directory = config_directory + # Load the library + self.lib = load_lib(VERSION) + initialize_functions(self.lib) + + self.event_handler = EventHandler(self.lib) + # Callbacks that need to wait for specific events # The ask profile callback needs to wait for the UI thread to select a profile @@ -57,22 +65,22 @@ class EduVPN(object): # The functions all have at least one arg type which is the name of the client args_gen = encode_args(list(args), func.argtypes[1:]) res = func(self.name.encode("utf-8"), *(args_gen)) - return decode_res(func.restype)(res) + return decode_res(func.restype)(self.lib, res) def go_function_custom_decode(self, func, decode_func, *args): # The functions all have at least one arg type which is the name of the client args_gen = encode_args(list(args), func.argtypes[1:]) res = func(self.name.encode("utf-8"), *(args_gen)) - return decode_func(res) + return decode_func(self.lib, res) def cancel_oauth(self) -> None: - cancel_oauth_err = self.go_function(lib.CancelOAuth) + cancel_oauth_err = self.go_function(self.lib.CancelOAuth) if cancel_oauth_err: raise cancel_oauth_err def deregister(self) -> None: - self.go_function(lib.Deregister) + self.go_function(self.lib.Deregister) remove_as_global_object(self) def register(self, debug: bool = False) -> None: @@ -80,7 +88,7 @@ class EduVPN(object): raise Exception("Already registered") register_err = self.go_function( - lib.Register, self.config_directory, state_callback, debug + self.lib.Register, self.config_directory, state_callback, debug ) if register_err: @@ -88,38 +96,40 @@ class EduVPN(object): def get_disco_servers(self) -> str: servers, servers_err = self.go_function_custom_decode( - lib.GetDiscoServers, decode_func=lambda x: get_data_error(x, get_disco_servers) + self.lib.GetDiscoServers, + decode_func=lambda lib, x: get_data_error(lib, x, get_disco_servers), ) if servers_err: - raise servers_err + raise servers_err return servers def get_disco_organizations(self) -> str: organizations, organizations_err = self.go_function_custom_decode( - lib.GetDiscoOrganizations, decode_func=lambda x: get_data_error(x, get_disco_organizations) + self.lib.GetDiscoOrganizations, + decode_func=lambda lib, x: get_data_error(lib, x, get_disco_organizations), ) if organizations_err: - raise organizations_err + raise organizations_err return organizations def remove_secure_internet(self): - remove_err = self.go_function(lib.RemoveSecureInternet) + remove_err = self.go_function(self.lib.RemoveSecureInternet) if remove_err: raise remove_err def remove_institute_access(self, url: str): - remove_err = self.go_function(lib.RemoveInstituteAccess, url) + remove_err = self.go_function(self.lib.RemoveInstituteAccess, url) if remove_err: raise remove_err def remove_custom_server(self, url: str): - remove_err = self.go_function(lib.RemoveCustomServer, url) + remove_err = self.go_function(self.lib.RemoveCustomServer, url) if remove_err: raise remove_err @@ -143,49 +153,49 @@ class EduVPN(object): def get_config_custom_server( self, url: str, force_tcp: bool = False ) -> Tuple[str, str]: - return self.get_config(url, lib.GetConfigCustomServer, force_tcp) + return self.get_config(url, self.lib.GetConfigCustomServer, force_tcp) def get_config_institute_access( self, url: str, force_tcp: bool = False ) -> Tuple[str, str]: - return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp) + return self.get_config(url, self.lib.GetConfigInstituteAccess, force_tcp) def get_config_secure_internet( self, url: str, force_tcp: bool = False ) -> Tuple[str, str]: self.location_event = threading.Event() - return self.get_config(url, lib.GetConfigSecureInternet, force_tcp) + return self.get_config(url, self.lib.GetConfigSecureInternet, force_tcp) def go_back(self) -> None: # Ignore the error - self.go_function(lib.GoBack) + self.go_function(self.lib.GoBack) def set_connected(self) -> None: - connect_err = self.go_function(lib.SetConnected) + connect_err = self.go_function(self.lib.SetConnected) if connect_err: raise connect_err def set_disconnecting(self) -> None: - disconnecting_err = self.go_function(lib.SetDisconnecting) + disconnecting_err = self.go_function(self.lib.SetDisconnecting) if disconnecting_err: raise disconnecting_err def set_connecting(self) -> None: - connecting_err = self.go_function(lib.SetConnecting) + connecting_err = self.go_function(self.lib.SetConnecting) if connecting_err: raise connecting_err def set_disconnected(self, cleanup=True) -> None: - disconnect_err = self.go_function(lib.SetDisconnected, cleanup) + disconnect_err = self.go_function(self.lib.SetDisconnected, cleanup) if disconnect_err: raise disconnect_err def set_search_server(self) -> None: - search_err = self.go_function(lib.SetSearchServer) + search_err = self.go_function(self.lib.SetSearchServer) if search_err: raise search_err @@ -205,7 +215,7 @@ class EduVPN(object): def set_profile(self, profile_id: str) -> None: # Set the profile id - profile_err = self.go_function(lib.SetProfileID, profile_id) + profile_err = self.go_function(self.lib.SetProfileID, profile_id) # If there is a profile event, set it so that the wait callback finishes # And so that the Go code can move to the next state @@ -218,14 +228,14 @@ class EduVPN(object): def change_secure_location(self) -> None: # Set the location by country code self.location_event = threading.Event() - location_err = self.go_function(lib.ChangeSecureLocation) + location_err = self.go_function(self.lib.ChangeSecureLocation) if location_err: raise location_err def set_secure_location(self, country_code: str) -> None: # Set the location by country code - location_err = self.go_function(lib.SetSecureLocation, country_code) + location_err = self.go_function(self.lib.SetSecureLocation, country_code) # If there is a location event, set it so that the wait callback finishes # And so that the Go code can move to the next state @@ -236,18 +246,19 @@ class EduVPN(object): raise location_err def renew_session(self) -> None: - renew_err = self.go_function(lib.RenewSession) + renew_err = self.go_function(self.lib.RenewSession) if renew_err: raise renew_err def should_renew_button(self) -> bool: - return self.go_function(lib.ShouldRenewButton) + return self.go_function(self.lib.ShouldRenewButton) def in_fsm_state(self, state_id: State) -> bool: - return self.go_function(lib.InFSMState, state_id) + return self.go_function(self.lib.InFSMState, state_id) def get_saved_servers(self) -> str: return self.go_function_custom_decode( - lib.GetSavedServers, decode_func=lambda x: get_data_error(x, get_servers) + self.lib.GetSavedServers, + decode_func=lambda lib, x: get_data_error(lib, x, get_servers), ) diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py index 470f704..36c3643 100644 --- a/wrappers/python/eduvpn_common/server.py +++ b/wrappers/python/eduvpn_common/server.py @@ -1,5 +1,5 @@ -from eduvpn_common import lib, cServer, cServers, cServerLocations, cServerProfiles -from ctypes import cast, POINTER, c_char_p +from eduvpn_common.types import cServer, cServers, cServerLocations, cServerProfiles +from ctypes import cast, POINTER from datetime import datetime @@ -125,19 +125,19 @@ def get_server(ptr, _type=None): return Server(identifier, display_name, profiles, current_server.expire_time) -def get_transition_server(ptr): +def get_transition_server(lib, ptr): server = get_server(cast(ptr, POINTER(cServer))) lib.FreeServer(ptr) return server -def get_transition_profiles(ptr): +def get_transition_profiles(lib, ptr): profiles = get_profiles(cast(ptr, POINTER(cServerProfiles))) lib.FreeProfiles(ptr) return profiles -def get_servers(ptr): +def get_servers(lib, ptr): if ptr: returned = [] servers = cast(ptr, POINTER(cServers)).contents @@ -164,7 +164,7 @@ def get_servers(ptr): return None -def get_locations(ptr): +def get_locations(lib, ptr): if ptr: locations = cast(ptr, POINTER(cServerLocations)).contents location_list = [] diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py new file mode 100644 index 0000000..c543989 --- /dev/null +++ b/wrappers/python/eduvpn_common/types.py @@ -0,0 +1,175 @@ +from ctypes import * +from eduvpn_common.error import ErrorLevel, WrappedError +from typing import List, Optional, Tuple + + +class cError(Structure): + _fields_ = [ + ("level", c_int), + ("traceback", c_char_p), + ("cause", c_char_p), + ] + + +class cServerLocations(Structure): + _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] + + +class cDiscoveryOrganization(Structure): + _fields_ = [ + ("display_name", c_char_p), + ("org_id", c_char_p), + ("secure_internet_home", c_char_p), + ("keyword_list", c_char_p), + ] + + +class cDiscoveryOrganizations(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("organizations", POINTER(POINTER(cDiscoveryOrganization))), + ("total_organizations", c_size_t), + ] + + +class cDiscoveryServer(Structure): + _fields_ = [ + ("authentication_url_template", c_char_p), + ("base_url", c_char_p), + ("country_code", c_char_p), + ("display_name", c_char_p), + ("keyword_list", c_char_p), + ("public_key_list", POINTER(c_char_p)), + ("total_public_keys", c_size_t), + ("server_type", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ] + + +class cDiscoveryServers(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("servers", POINTER(POINTER(cDiscoveryServer))), + ("total_servers", c_size_t), + ] + + +class cServerProfile(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("default_gateway", c_int), + ] + + +class cServerProfiles(Structure): + _fields_ = [ + ("current", c_int), + ("profiles", POINTER(POINTER(cServerProfile))), + ("total_profiles", c_size_t), + ] + + +class cServer(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("server_type", c_char_p), + ("country_code", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ("profiles", POINTER(cServerProfiles)), + ("expire_time", c_ulonglong), + ] + + +class cServers(Structure): + _fields_ = [ + ("custom_servers", POINTER(POINTER(cServer))), + ("total_custom", c_size_t), + ("institute_servers", POINTER(POINTER(cServer))), + ("total_institute", c_size_t), + ("secure_internet", POINTER(cServer)), + ] + + +class DataError(Structure): + _fields_ = [("data", c_void_p), ("error", c_void_p)] + + +class ConfigError(Structure): + _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)] + + +VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) + + +def encode_args(args, types): + for arg, t in zip(args, types): + # c_char_p needs the str to be encoded to bytes + if t is c_char_p: + arg = arg.encode("utf-8") + yield arg + + +def decode_res(t): + decode_map = { + c_int: get_bool, + c_void_p: get_error, + DataError: get_data_error, + ConfigError: get_config_error, + } + return decode_map.get(t, lambda lib, x: x) + + +def get_ptr_string(lib, ptr: c_void_p) -> str: + if ptr: + string = cast(ptr, c_char_p).value + lib.FreeString(ptr) + if string: + return string.decode("utf-8") + return "" + + +def get_ptr_list_strings( + lib, strings: POINTER(c_char_p), total_strings: c_size_t +) -> List[str]: + if strings: + strings_list = [] + for i in range(int(total_strings)): + strings_list.append(strings[i].decode("utf-8")) + return strings_list + return [] + + +def get_error(lib, ptr: c_void_p) -> Optional[WrappedError]: + if not ptr: + return None + err = cast(ptr, POINTER(cError)).contents + wrapped = WrappedError( + err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level) + ) + lib.FreeError(ptr) + return wrapped + + +def get_config_error( + lib, config_error: ConfigError +) -> Tuple[str, str, Optional[WrappedError]]: + config = get_ptr_string(lib, config_error.config) + config_type = get_ptr_string(lib, config_error.config_type) + err = get_error(lib, config_error.error) + return config, config_type, err + + +def get_data_error( + lib, data_error: DataError, data_conv=get_ptr_string +) -> Tuple[str, Optional[WrappedError]]: + data = data_conv(lib, data_error.data) + error = get_error(lib, data_error.error) + return data, error + + +def get_bool(lib, boolInt: c_int) -> bool: + return boolInt == 1 |
