diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-16 10:46:28 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-16 10:46:28 +0200 |
| commit | 4bf1273c3f446ac3195fb700ec41c7cae7d20ac9 (patch) | |
| tree | cec8d9e405b7d6786023ca9b921a6f0473d28a71 /wrappers | |
| parent | 02db081c85e56e6472c2f39e6a623fa4cdf359c4 (diff) | |
Discovery: Expose c types
Diffstat (limited to 'wrappers')
| -rw-r--r-- | wrappers/python/src/__init__.py | 73 | ||||
| -rw-r--r-- | wrappers/python/src/discovery.py | 85 | ||||
| -rw-r--r-- | wrappers/python/src/event.py | 26 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 20 | ||||
| -rw-r--r-- | wrappers/python/src/server.py | 109 |
5 files changed, 247 insertions, 66 deletions
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index ece4b46..ca07143 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -5,6 +5,7 @@ import pathlib import platform from typing import Tuple, Optional import json +from typing import List _lib_prefixes = defaultdict( lambda: "lib", @@ -40,11 +41,10 @@ class ErrorLevel(Enum): ERR_OTHER = 0 ERR_INFO = 1 + class cServerLocations(Structure): - _fields_ = [ - ("locations", POINTER(c_char_p)), - ("total_locations", c_size_t) - ] + _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] + class cDiscoveryOrganization(Structure): _fields_ = [ @@ -54,6 +54,7 @@ class cDiscoveryOrganization(Structure): ("keyword_list", c_char_p), ] + class cDiscoveryOrganizations(Structure): _fields_ = [ ("version", c_ulonglong), @@ -61,6 +62,30 @@ class cDiscoveryOrganizations(Structure): ("total_organizations", c_size_t), ] + +class cDiscoveryServer(Structure): + _fields_ = [ + ("authentication_url_template", c_char_p), + ("base_url", c_char_p), + ("country_code", c_char_p), + ("display_name", c_char_p), + ("keyword_list", c_char_p), + ("public_key_list", POINTER(c_char_p)), + ("total_public_keys", c_size_t), + ("server_type", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ] + + +class cDiscoveryServers(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("servers", POINTER(POINTER(cDiscoveryServer))), + ("total_servers", c_size_t), + ] + + class cServerProfile(Structure): _fields_ = [ ("identifier", c_char_p), @@ -68,6 +93,7 @@ class cServerProfile(Structure): ("default_gateway", c_int), ] + class cServerProfiles(Structure): _fields_ = [ ("current", c_int), @@ -75,10 +101,12 @@ class cServerProfiles(Structure): ("total_profiles", c_size_t), ] + class cServer(Structure): _fields_ = [ ("identifier", c_char_p), ("display_name", c_char_p), + ("server_type", c_char_p), ("country_code", c_char_p), ("support_contact", POINTER(c_char_p)), ("total_support_contact", c_size_t), @@ -86,6 +114,7 @@ class cServer(Structure): ("expire_time", c_ulonglong), ] + class cServers(Structure): _fields_ = [ ("custom_servers", POINTER(POINTER(cServer))), @@ -95,6 +124,7 @@ class cServers(Structure): ("secure_internet", POINTER(cServer)), ] + class DataError(Structure): _fields_ = [("data", c_void_p), ("error", c_void_p)] @@ -104,9 +134,17 @@ VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error -lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [c_char_p], c_void_p -lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [c_char_p, c_char_p], c_void_p -lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [c_char_p, c_char_p], c_void_p +lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [ + c_char_p +], c_void_p +lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ + c_char_p, + c_char_p, +], c_void_p +lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ + c_char_p, + c_char_p, +], c_void_p lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ c_char_p, c_char_p, @@ -132,7 +170,7 @@ lib.Register.argtypes, lib.Register.restype = [ lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ c_char_p ], c_void_p -lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError +lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], c_void_p lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p @@ -150,9 +188,14 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p +lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None -lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None +lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ + c_void_p +], None +lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None +lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p @@ -186,6 +229,17 @@ def get_ptr_string(ptr: c_void_p) -> str: return "" +def get_ptr_list_strings( + strings: POINTER(c_char_p), total_strings: c_size_t +) -> List[str]: + if strings: + strings_list = [] + for i in range(total_strings): + strings_list.append(strings[i].decode("utf-8")) + return strings_list + return [] + + def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]: error_string = get_ptr_string(ptr) @@ -224,6 +278,7 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]: def get_bool(boolInt: c_int) -> bool: return boolInt == 1 + decode_map = { c_int: get_bool, c_void_p: get_error, diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py index 80c08cf..a1b87ea 100644 --- a/wrappers/python/src/discovery.py +++ b/wrappers/python/src/discovery.py @@ -1,4 +1,4 @@ -from . import lib, cDiscoveryOrganizations +from . import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings from ctypes import cast, POINTER @@ -9,6 +9,9 @@ class DiscoOrganization: self.secure_internet_home = secure_internet_home self.keyword_list = keyword_list + def __str__(self): + return self.display_name + class DiscoOrganizations: def __init__(self, version, organizations): @@ -16,6 +19,37 @@ class DiscoOrganizations: self.organizations = organizations +class DiscoServer: + def __init__( + self, + authentication_url_template, + base_url, + country_code, + display_name, + keyword_list, + public_keys, + server_type, + support_contacts, + ): + self.authentication_url_template = authentication_url_template + self.base_url = base_url + self.country_code = country_code + self.display_name = display_name + self.keyword_list = keyword_list + self.public_keys = public_keys + self.server_type = server_type + self.support_contacts = support_contacts + + def __str__(self): + return self.display_name + + +class DiscoServers: + def __init__(self, version, servers): + self.version = version + self.servers = servers + + def get_disco_organization(ptr): if not ptr: return None @@ -28,6 +62,55 @@ def get_disco_organization(ptr): return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) +def get_disco_server(ptr): + if not ptr: + return None + + current_server = ptr.contents + authentication_url_template = current_server.authentication_url_template.decode( + "utf-8" + ) + base_url = current_server.base_url.decode("utf-8") + country_code = current_server.country_code.decode("utf-8") + display_name = current_server.display_name.decode("utf-8") + keyword_list = current_server.keyword_list.decode("utf-8") + public_keys = get_ptr_list_strings( + current_server.public_key_list, current_server.total_public_keys + ) + server_type = current_server.server_type.decode("utf-8") + support_contacts = get_ptr_list_strings( + current_server.support_contact, current_server.total_support_contact + ) + return DiscoServer( + authentication_url_template, + base_url, + country_code, + display_name, + keyword_list, + public_keys, + server_type, + support_contacts, + ) + + +def get_disco_servers(ptr): + if ptr: + svrs = cast(ptr, POINTER(cDiscoveryServers)).contents + + servers = [] + + if svrs.servers: + for i in range(svrs.total_servers): + current = get_disco_server(svrs.servers[i]) + + if current is None: + continue + servers.append(current) + lib.FreeDiscoServers(ptr) + return DiscoServers(svrs.version, servers) + return None + + def get_disco_organizations(ptr): if ptr: orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py index 0e0f5ae..cf1a9d0 100644 --- a/wrappers/python/src/event.py +++ b/wrappers/python/src/event.py @@ -2,7 +2,12 @@ from . import VPNStateChange, get_ptr_string from enum import Enum from typing import Callable from .state import State, StateType -from .server import get_locations, get_servers +from .server import ( + get_locations, + get_transition_profiles, + get_transition_server, + get_servers, +) EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" @@ -16,6 +21,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable: return wrapper + def convert_data(state: State, data): if not data: return None @@ -25,6 +31,16 @@ def convert_data(state: State, data): return get_ptr_string(data) if state is State.ASK_LOCATION: return get_locations(data) + if state is State.ASK_PROFILE: + return get_transition_profiles(data) + if state in [ + State.DISCONNECTED, + State.DISCONNECTING, + State.CONNECTING, + State.CONNECTED, + ]: + return get_transition_server(data) + class EventHandler(object): def __init__(self): @@ -80,10 +96,14 @@ class EventHandler(object): for func in self.handlers[(state, state_type)]: func(other_state, data) - def run(self, old_state: int, new_state: int, data: str) -> None: + def run( + self, old_state: int, new_state: int, data: str, convert: bool = True + ) -> None: # First run leave transitions, then enter # The state is done when the wait event finishes - converted = convert_data(new_state, data) + converted = data + if convert: + converted = convert_data(new_state, data) self.run_state(old_state, new_state, StateType.Leave, converted) self.run_state(new_state, old_state, StateType.Enter, converted) self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 8425992..e20b84f 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,7 +1,7 @@ from . import lib, VPNStateChange, encode_args, decode_res from typing import Optional, Tuple import threading -from .discovery import get_disco_organizations +from .discovery import get_disco_organizations, get_disco_servers from .event import EventHandler from .state import State, StateType from .server import get_servers @@ -88,16 +88,20 @@ class EduVPN(object): raise Exception(register_err) def get_disco_servers(self) -> str: - servers, servers_err = self.go_function(lib.GetDiscoServers) + servers = self.go_function_custom_decode( + lib.GetDiscoServers, decode_func=get_disco_servers + ) - if servers_err: - raise Exception(servers_err) + # if servers_err: + # raise Exception(servers_err) return servers def get_disco_organizations(self) -> str: - organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations) - #if organizations_err: + organizations = self.go_function_custom_decode( + lib.GetDiscoOrganizations, decode_func=get_disco_organizations + ) + # if organizations_err: # raise Exception(organizations_err) return organizations @@ -251,4 +255,6 @@ class EduVPN(object): return self.go_function(lib.GetSavedServersOLD) def get_saved_servers_new(self) -> str: - return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers) + return self.go_function_custom_decode( + lib.GetSavedServersNEW, decode_func=get_servers + ) diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py index b765ede..dce5d51 100644 --- a/wrappers/python/src/server.py +++ b/wrappers/python/src/server.py @@ -1,5 +1,6 @@ -from . import lib, cServers, cServerLocations -from ctypes import cast, POINTER +from . import lib, cServer, cServers, cServerLocations, cServerProfiles +from ctypes import cast, POINTER, c_char_p +from datetime import datetime class Profile: @@ -9,61 +10,85 @@ class Profile: self.default_gateway = default_gateway def __str__(self): - return f"Profile: {self.display_name}" + return self.display_name + + +class Profiles: + def __init__(self, profiles, current): + self.profiles = profiles + self.current_index = current + + @property + def current(self): + if self.current_index < len(self.profiles): + return self.profiles[self.current_index] + return None class Server: - def __init__(self, url, display_name, profiles, current_profile, expire_time): + def __init__(self, url, display_name, profiles=None, expire_time=0): self.url = url self.display_name = display_name self.profiles = profiles self.current_profile = None - if current_profile < len(profiles): - self.current_profile = profiles[current_profile] - self.expire_time = expire_time + self.expire_time = datetime.fromtimestamp(expire_time) def __str__(self): - return f"Server: {self.url}, with current profile: {self.current_profile}" + return self.display_name class InstituteServer(Server): - def __init__( - self, url, display_name, support_contact, profiles, current_profile, expire_time - ): - super().__init__(url, display_name, profiles, current_profile, expire_time) + def __init__(self, url, display_name, support_contact, profiles, expire_time): + super().__init__(url, display_name, profiles, expire_time) self.support_contact = support_contact - def __str__(self): - return f"Institute Server: {self.display_name}" - class SecureInternetServer(Server): def __init__( self, - url, + org_id, display_name, support_contact, profiles, - current_profile, expire_time, country_code, ): - super().__init__(url, display_name, profiles, current_profile, expire_time) + super().__init__(org_id, display_name, profiles, expire_time) + self.org_id = org_id self.support_contact = support_contact self.country_code = country_code - def __str__(self): - return f"Secure Internet Server: {self.display_name} with country {self.country_code}" - def get_type_for_str(type_str: str): - if type_str is "secure_internet": + if type_str == "secure_internet": return SecureInternetServer - if type_str is "custom_server": + if type_str == "custom_server": return Server return InstituteServer +def get_profiles(ptr): + if not ptr: + return [] + profiles = [] + _profiles = ptr.contents + current_profile = _profiles.current + if not _profiles.profiles: + return [] + for i in range(_profiles.total_profiles): + if not _profiles.profiles[i]: + continue + profile = _profiles.profiles[i].contents + profiles.append( + Profile( + profile.identifier.decode("utf-8"), + profile.display_name.decode("utf-8"), + profile.default_gateway == 1, + ) + ) + return Profiles(profiles, current_profile) + + def get_server(ptr, _type=None): if not ptr: return None @@ -79,31 +104,13 @@ def get_server(ptr, _type=None): support_contact = [] for i in range(current_server.total_support_contact): support_contact.append(current_server.support_contact[i].decode("utf-8")) - profiles = [] - if not current_server.profiles: - return None - - _profiles = current_server.profiles.contents - current_profile = _profiles.current - for i in range(_profiles.total_profiles): - if not _profiles.profiles or not _profiles.profiles[i]: - return None - profile = _profiles.profiles[i].contents - profiles.append( - Profile( - profile.identifier.decode("utf-8"), - profile.display_name.decode("utf-8"), - profile.default_gateway == 1, - ) - ) - + profiles = get_profiles(current_server.profiles) if _type is SecureInternetServer: return SecureInternetServer( identifier, display_name, support_contact, profiles, - current_profile, current_server.expire_time, current_server.country_code.decode("utf-8"), ) @@ -113,12 +120,21 @@ def get_server(ptr, _type=None): display_name, support_contact, profiles, - current_profile, current_server.expire_time, ) - return Server( - identifier, display_name, profiles, current_profile, current_server.expire_time - ) + return Server(identifier, display_name, profiles, current_server.expire_time) + + +def get_transition_server(ptr): + server = get_server(cast(ptr, POINTER(cServer))) + lib.FreeServer(ptr) + return server + + +def get_transition_profiles(ptr): + profiles = get_profiles(cast(ptr, POINTER(cServerProfiles))) + lib.FreeProfiles(ptr) + return profiles def get_servers(ptr): @@ -147,6 +163,7 @@ def get_servers(ptr): return returned return None + def get_locations(ptr): if ptr: locations = cast(ptr, POINTER(cServerLocations)).contents |
