diff options
Diffstat (limited to 'wrappers')
| -rw-r--r-- | wrappers/python/eduvpn_common/event.py | 14 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 6 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 81 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/server.py | 52 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 24 |
5 files changed, 89 insertions, 88 deletions
diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py index f6260c5..1e42ef4 100644 --- a/wrappers/python/eduvpn_common/event.py +++ b/wrappers/python/eduvpn_common/event.py @@ -1,5 +1,5 @@ -from enum import Enum -from typing import Callable +from ctypes import c_void_p, CDLL +from typing import Any, Callable, Dict, List, Tuple from eduvpn_common.state import State, StateType from eduvpn_common.server import ( get_locations, @@ -22,7 +22,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable: return wrapper -def convert_data(lib, state: State, data): +def convert_data(lib: CDLL, state: int, data: Any): if not data: return None if state is State.NO_SERVER: @@ -43,11 +43,11 @@ def convert_data(lib, state: State, data): class EventHandler(object): - def __init__(self, lib): - self.handlers = {} + def __init__(self, lib: CDLL): + self.handlers: Dict[Tuple[int, StateType], List[Callable]] = {} self.lib = lib - def change_class_callbacks(self, cls, add=True) -> None: + def change_class_callbacks(self, cls: Any, add: bool = True) -> None: # Loop over method names for method_name in dir(cls): try: @@ -98,7 +98,7 @@ class EventHandler(object): func(other_state, data) def run( - self, old_state: int, new_state: int, data: str, convert: bool = True + self, old_state: int, new_state: int, data: Any, convert: bool = True ) -> None: # First run leave transitions, then enter # The state is done when the wait event finishes diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index 23851f3..a5eec3f 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -1,9 +1,9 @@ -from ctypes import * +from ctypes import cdll, CDLL, c_char_p, c_int, c_void_p from collections import defaultdict import pathlib import platform from eduvpn_common import __version__ -from eduvpn_common.types import * +from eduvpn_common.types import cError, cServer, cServers, cServerProfiles, cServerLocations, cDiscoveryServer, cDiscoveryServers, ConfigError, DataError, VPNStateChange def load_lib(): @@ -39,7 +39,7 @@ def load_lib(): return lib -def initialize_functions(lib): +def initialize_functions(lib: CDLL): # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index 382f356..1271fb2 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple +from ctypes import c_char_p, c_int, c_void_p +from typing import Any, Callable, Dict, Optional, Tuple import threading from eduvpn_common.discovery import get_disco_organizations, get_disco_servers from eduvpn_common.event import EventHandler @@ -7,30 +8,6 @@ from eduvpn_common.types import VPNStateChange, encode_args, decode_res, get_dat from eduvpn_common.server import get_servers from eduvpn_common.state import State, StateType -eduvpn_objects = {} - - -def add_as_global_object(eduvpn) -> bool: - global eduvpn_objects - if eduvpn.name not in eduvpn_objects: - eduvpn_objects[eduvpn.name] = eduvpn - return True - return False - - -def remove_as_global_object(eduvpn): - global eduvpn_objects - eduvpn_objects.pop(eduvpn.name, None) - - -@VPNStateChange -def state_callback(name, old_state, new_state, data): - name = name.decode() - if name not in eduvpn_objects: - return - eduvpn_objects[name].callback(State(old_state), State(new_state), data) - - class EduVPN(object): def __init__(self, name: str, config_directory: str, language: str): self.name = name @@ -60,17 +37,14 @@ class EduVPN(object): if self.location_event: self.location_event.wait() - def go_function(self, func, *args): + def go_function(self, func: Any, *args, decode_func: Optional[Callable] = None): # The functions all have at least one arg type which is the name of the client args_gen = encode_args(list(args), func.argtypes[1:]) res = func(self.name.encode("utf-8"), *(args_gen)) - return decode_res(func.restype)(self.lib, res) - - def go_function_custom_decode(self, func, decode_func, *args): - # The functions all have at least one arg type which is the name of the client - args_gen = encode_args(list(args), func.argtypes[1:]) - res = func(self.name.encode("utf-8"), *(args_gen)) - return decode_func(self.lib, res) + if decode_func is None: + return decode_res(func.restype)(self.lib, res) + else: + return decode_func(self.lib, res) def cancel_oauth(self) -> None: cancel_oauth_err = self.go_function(self.lib.CancelOAuth) @@ -94,7 +68,7 @@ class EduVPN(object): raise register_err def get_disco_servers(self) -> str: - servers, servers_err = self.go_function_custom_decode( + servers, servers_err = self.go_function( self.lib.GetDiscoServers, decode_func=lambda lib, x: get_data_error(lib, x, get_disco_servers), ) @@ -105,7 +79,7 @@ class EduVPN(object): return servers def get_disco_organizations(self) -> str: - organizations, organizations_err = self.go_function_custom_decode( + organizations, organizations_err = self.go_function( self.lib.GetDiscoOrganizations, decode_func=lambda lib, x: get_data_error(lib, x, get_disco_organizations), ) @@ -152,7 +126,7 @@ class EduVPN(object): if remove_err: raise remove_err - def get_config(self, url: str, func: callable, prefer_tcp: bool = False): + def get_config(self, url: str, func: Any, prefer_tcp: bool = False): # Because it could be the case that a profile callback is started, store a threading event # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set # The event is set in self.set_profile @@ -205,7 +179,7 @@ class EduVPN(object): if connecting_err: raise connecting_err - def set_disconnected(self, cleanup=True) -> None: + def set_disconnected(self, cleanup: bool = True) -> None: disconnect_err = self.go_function(self.lib.SetDisconnected, cleanup) if disconnect_err: @@ -217,17 +191,17 @@ class EduVPN(object): if search_err: raise search_err - def remove_class_callbacks(self, cls) -> None: + def remove_class_callbacks(self, cls: Any) -> None: self.event_handler.change_class_callbacks(cls, add=False) - def register_class_callbacks(self, cls) -> None: + def register_class_callbacks(self, cls: Any) -> None: self.event_handler.change_class_callbacks(cls) @property def event(self) -> EventHandler: return self.event_handler - def callback(self, old_state: State, new_state: State, data) -> None: + def callback(self, old_state: State, new_state: State, data: Any) -> None: self.event.run(old_state, new_state, data) def set_profile(self, profile_id: str) -> None: @@ -275,7 +249,7 @@ class EduVPN(object): return self.go_function(self.lib.InFSMState, state_id) def get_saved_servers(self): - servers, servers_err = self.go_function_custom_decode( + servers, servers_err = self.go_function( self.lib.GetSavedServers, decode_func=lambda lib, x: get_data_error(lib, x, get_servers), ) @@ -284,3 +258,28 @@ class EduVPN(object): raise servers_err return servers + +eduvpn_objects: Dict[str, EduVPN] = {} + + +@VPNStateChange +def state_callback(name: bytes, old_state: int, new_state: int, data: Any): + name_decoded = name.decode() + if name_decoded not in eduvpn_objects: + return + eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data) + + + +def add_as_global_object(eduvpn: EduVPN) -> bool: + global eduvpn_objects + if eduvpn.name not in eduvpn_objects: + eduvpn_objects[eduvpn.name] = eduvpn + return True + return False + + +def remove_as_global_object(eduvpn: EduVPN): + global eduvpn_objects + eduvpn_objects.pop(eduvpn.name, None) + diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py index 71b6487..01b5204 100644 --- a/wrappers/python/eduvpn_common/server.py +++ b/wrappers/python/eduvpn_common/server.py @@ -1,10 +1,11 @@ +from typing import List, Optional, Type from eduvpn_common.types import cServer, cServers, cServerLocations, cServerProfiles -from ctypes import cast, POINTER +from ctypes import c_void_p, cast, POINTER, CDLL from datetime import datetime class Profile: - def __init__(self, identifier, display_name, default_gateway: bool): + def __init__(self, identifier: str, display_name: str, default_gateway: bool): self.identifier = identifier self.display_name = display_name self.default_gateway = default_gateway @@ -14,19 +15,19 @@ class Profile: class Profiles: - def __init__(self, profiles, current): + def __init__(self, profiles: List[Profile], current: int): self.profiles = profiles self.current_index = current @property - def current(self): + def current(self) -> Optional[Profile]: if self.current_index < len(self.profiles): return self.profiles[self.current_index] return None class Server: - def __init__(self, url, display_name, profiles=None, expire_time=0): + def __init__(self, url: str, display_name: str, profiles: Optional[Profiles] = None, expire_time: int = 0): self.url = url self.display_name = display_name self.profiles = profiles @@ -36,29 +37,28 @@ class Server: return self.display_name @property - def category(self): + def category(self) -> str: return "Custom Server" class InstituteServer(Server): - def __init__(self, url, display_name, support_contact, profiles, expire_time): + def __init__(self, url: str, display_name: str, support_contact: List[str], profiles: Profiles, expire_time: int): super().__init__(url, display_name, profiles, expire_time) self.support_contact = support_contact @property - def category(self): + def category(self) -> str: return "Institute Access Server" - class SecureInternetServer(Server): def __init__( self, - org_id, - display_name, - support_contact, - profiles, - expire_time, - country_code, + org_id: str, + display_name: str, + support_contact: List[str], + profiles: Profiles, + expire_time: int, + country_code: str, ): super().__init__(org_id, display_name, profiles, expire_time) self.org_id = org_id @@ -66,11 +66,11 @@ class SecureInternetServer(Server): self.country_code = country_code @property - def category(self): + def category(self) -> str: return "Secure Internet Server" -def get_type_for_str(type_str: str): +def get_type_for_str(type_str: str) -> Type[Server]: if type_str == "secure_internet": return SecureInternetServer if type_str == "custom_server": @@ -78,14 +78,14 @@ def get_type_for_str(type_str: str): return InstituteServer -def get_profiles(ptr): +def get_profiles(ptr) -> Optional[Profiles]: if not ptr: - return [] + return None profiles = [] _profiles = ptr.contents current_profile = _profiles.current if not _profiles.profiles: - return [] + return None for i in range(_profiles.total_profiles): if not _profiles.profiles[i]: continue @@ -100,7 +100,7 @@ def get_profiles(ptr): return Profiles(profiles, current_profile) -def get_server(ptr, _type=None): +def get_server(ptr, _type=None) -> Optional[Server]: if not ptr: return None @@ -116,6 +116,8 @@ def get_server(ptr, _type=None): for i in range(current_server.total_support_contact): support_contact.append(current_server.support_contact[i].decode("utf-8")) profiles = get_profiles(current_server.profiles) + if profiles is None: + return None if _type is SecureInternetServer: return SecureInternetServer( identifier, @@ -136,19 +138,19 @@ def get_server(ptr, _type=None): return Server(identifier, display_name, profiles, current_server.expire_time) -def get_transition_server(lib, ptr): +def get_transition_server(lib: CDLL, ptr: c_void_p) -> Optional[Server]: server = get_server(cast(ptr, POINTER(cServer))) lib.FreeServer(ptr) return server -def get_transition_profiles(lib, ptr): +def get_transition_profiles(lib: CDLL, ptr: c_void_p) -> Optional[Profiles]: profiles = get_profiles(cast(ptr, POINTER(cServerProfiles))) lib.FreeProfiles(ptr) return profiles -def get_servers(lib, ptr): +def get_servers(lib: CDLL, ptr: c_void_p) -> Optional[List[Server]]: if ptr: returned = [] servers = cast(ptr, POINTER(cServers)).contents @@ -175,7 +177,7 @@ def get_servers(lib, ptr): return None -def get_locations(lib, ptr): +def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]: if ptr: locations = cast(ptr, POINTER(cServerLocations)).contents location_list = [] diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index c543989..82987c0 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -1,6 +1,6 @@ -from ctypes import * +from ctypes import Structure, c_int, c_char_p, c_size_t, c_ulonglong, c_void_p, CFUNCTYPE, POINTER, CDLL, cast, pointer from eduvpn_common.error import ErrorLevel, WrappedError -from typing import List, Optional, Tuple +from typing import Any, Callable, Iterator, List, Optional, Tuple class cError(Structure): @@ -105,7 +105,7 @@ class ConfigError(Structure): VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) -def encode_args(args, types): +def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: for arg, t in zip(args, types): # c_char_p needs the str to be encoded to bytes if t is c_char_p: @@ -113,17 +113,17 @@ def encode_args(args, types): yield arg -def decode_res(t): +def decode_res(res: Any): decode_map = { c_int: get_bool, c_void_p: get_error, DataError: get_data_error, ConfigError: get_config_error, } - return decode_map.get(t, lambda lib, x: x) + return decode_map.get(res, lambda lib, x: x) -def get_ptr_string(lib, ptr: c_void_p) -> str: +def get_ptr_string(lib: CDLL, ptr: c_void_p) -> str: if ptr: string = cast(ptr, c_char_p).value lib.FreeString(ptr) @@ -133,17 +133,17 @@ def get_ptr_string(lib, ptr: c_void_p) -> str: def get_ptr_list_strings( - lib, strings: POINTER(c_char_p), total_strings: c_size_t + lib: CDLL, strings: pointer, total_strings: int ) -> List[str]: if strings: strings_list = [] - for i in range(int(total_strings)): + for i in range(total_strings): strings_list.append(strings[i].decode("utf-8")) return strings_list return [] -def get_error(lib, ptr: c_void_p) -> Optional[WrappedError]: +def get_error(lib: CDLL, ptr: c_void_p) -> Optional[WrappedError]: if not ptr: return None err = cast(ptr, POINTER(cError)).contents @@ -155,7 +155,7 @@ def get_error(lib, ptr: c_void_p) -> Optional[WrappedError]: def get_config_error( - lib, config_error: ConfigError + lib: CDLL, config_error: ConfigError ) -> Tuple[str, str, Optional[WrappedError]]: config = get_ptr_string(lib, config_error.config) config_type = get_ptr_string(lib, config_error.config_type) @@ -164,12 +164,12 @@ def get_config_error( def get_data_error( - lib, data_error: DataError, data_conv=get_ptr_string + lib: CDLL, data_error: DataError, data_conv: Callable = get_ptr_string ) -> Tuple[str, Optional[WrappedError]]: data = data_conv(lib, data_error.data) error = get_error(lib, data_error.error) return data, error -def get_bool(lib, boolInt: c_int) -> bool: +def get_bool(lib: CDLL, boolInt: c_int) -> bool: return boolInt == 1 |
