summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--wrappers/python/eduvpn_common/event.py14
-rw-r--r--wrappers/python/eduvpn_common/loader.py6
-rw-r--r--wrappers/python/eduvpn_common/main.py81
-rw-r--r--wrappers/python/eduvpn_common/server.py52
-rw-r--r--wrappers/python/eduvpn_common/types.py24
5 files changed, 89 insertions, 88 deletions
diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py
index f6260c5..1e42ef4 100644
--- a/wrappers/python/eduvpn_common/event.py
+++ b/wrappers/python/eduvpn_common/event.py
@@ -1,5 +1,5 @@
-from enum import Enum
-from typing import Callable
+from ctypes import c_void_p, CDLL
+from typing import Any, Callable, Dict, List, Tuple
from eduvpn_common.state import State, StateType
from eduvpn_common.server import (
get_locations,
@@ -22,7 +22,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
-def convert_data(lib, state: State, data):
+def convert_data(lib: CDLL, state: int, data: Any):
if not data:
return None
if state is State.NO_SERVER:
@@ -43,11 +43,11 @@ def convert_data(lib, state: State, data):
class EventHandler(object):
- def __init__(self, lib):
- self.handlers = {}
+ def __init__(self, lib: CDLL):
+ self.handlers: Dict[Tuple[int, StateType], List[Callable]] = {}
self.lib = lib
- def change_class_callbacks(self, cls, add=True) -> None:
+ def change_class_callbacks(self, cls: Any, add: bool = True) -> None:
# Loop over method names
for method_name in dir(cls):
try:
@@ -98,7 +98,7 @@ class EventHandler(object):
func(other_state, data)
def run(
- self, old_state: int, new_state: int, data: str, convert: bool = True
+ self, old_state: int, new_state: int, data: Any, convert: bool = True
) -> None:
# First run leave transitions, then enter
# The state is done when the wait event finishes
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
index 23851f3..a5eec3f 100644
--- a/wrappers/python/eduvpn_common/loader.py
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -1,9 +1,9 @@
-from ctypes import *
+from ctypes import cdll, CDLL, c_char_p, c_int, c_void_p
from collections import defaultdict
import pathlib
import platform
from eduvpn_common import __version__
-from eduvpn_common.types import *
+from eduvpn_common.types import cError, cServer, cServers, cServerProfiles, cServerLocations, cDiscoveryServer, cDiscoveryServers, ConfigError, DataError, VPNStateChange
def load_lib():
@@ -39,7 +39,7 @@ def load_lib():
return lib
-def initialize_functions(lib):
+def initialize_functions(lib: CDLL):
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index 382f356..1271fb2 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,4 +1,5 @@
-from typing import Optional, Tuple
+from ctypes import c_char_p, c_int, c_void_p
+from typing import Any, Callable, Dict, Optional, Tuple
import threading
from eduvpn_common.discovery import get_disco_organizations, get_disco_servers
from eduvpn_common.event import EventHandler
@@ -7,30 +8,6 @@ from eduvpn_common.types import VPNStateChange, encode_args, decode_res, get_dat
from eduvpn_common.server import get_servers
from eduvpn_common.state import State, StateType
-eduvpn_objects = {}
-
-
-def add_as_global_object(eduvpn) -> bool:
- global eduvpn_objects
- if eduvpn.name not in eduvpn_objects:
- eduvpn_objects[eduvpn.name] = eduvpn
- return True
- return False
-
-
-def remove_as_global_object(eduvpn):
- global eduvpn_objects
- eduvpn_objects.pop(eduvpn.name, None)
-
-
-@VPNStateChange
-def state_callback(name, old_state, new_state, data):
- name = name.decode()
- if name not in eduvpn_objects:
- return
- eduvpn_objects[name].callback(State(old_state), State(new_state), data)
-
-
class EduVPN(object):
def __init__(self, name: str, config_directory: str, language: str):
self.name = name
@@ -60,17 +37,14 @@ class EduVPN(object):
if self.location_event:
self.location_event.wait()
- def go_function(self, func, *args):
+ def go_function(self, func: Any, *args, decode_func: Optional[Callable] = None):
# The functions all have at least one arg type which is the name of the client
args_gen = encode_args(list(args), func.argtypes[1:])
res = func(self.name.encode("utf-8"), *(args_gen))
- return decode_res(func.restype)(self.lib, res)
-
- def go_function_custom_decode(self, func, decode_func, *args):
- # The functions all have at least one arg type which is the name of the client
- args_gen = encode_args(list(args), func.argtypes[1:])
- res = func(self.name.encode("utf-8"), *(args_gen))
- return decode_func(self.lib, res)
+ if decode_func is None:
+ return decode_res(func.restype)(self.lib, res)
+ else:
+ return decode_func(self.lib, res)
def cancel_oauth(self) -> None:
cancel_oauth_err = self.go_function(self.lib.CancelOAuth)
@@ -94,7 +68,7 @@ class EduVPN(object):
raise register_err
def get_disco_servers(self) -> str:
- servers, servers_err = self.go_function_custom_decode(
+ servers, servers_err = self.go_function(
self.lib.GetDiscoServers,
decode_func=lambda lib, x: get_data_error(lib, x, get_disco_servers),
)
@@ -105,7 +79,7 @@ class EduVPN(object):
return servers
def get_disco_organizations(self) -> str:
- organizations, organizations_err = self.go_function_custom_decode(
+ organizations, organizations_err = self.go_function(
self.lib.GetDiscoOrganizations,
decode_func=lambda lib, x: get_data_error(lib, x, get_disco_organizations),
)
@@ -152,7 +126,7 @@ class EduVPN(object):
if remove_err:
raise remove_err
- def get_config(self, url: str, func: callable, prefer_tcp: bool = False):
+ def get_config(self, url: str, func: Any, prefer_tcp: bool = False):
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
@@ -205,7 +179,7 @@ class EduVPN(object):
if connecting_err:
raise connecting_err
- def set_disconnected(self, cleanup=True) -> None:
+ def set_disconnected(self, cleanup: bool = True) -> None:
disconnect_err = self.go_function(self.lib.SetDisconnected, cleanup)
if disconnect_err:
@@ -217,17 +191,17 @@ class EduVPN(object):
if search_err:
raise search_err
- def remove_class_callbacks(self, cls) -> None:
+ def remove_class_callbacks(self, cls: Any) -> None:
self.event_handler.change_class_callbacks(cls, add=False)
- def register_class_callbacks(self, cls) -> None:
+ def register_class_callbacks(self, cls: Any) -> None:
self.event_handler.change_class_callbacks(cls)
@property
def event(self) -> EventHandler:
return self.event_handler
- def callback(self, old_state: State, new_state: State, data) -> None:
+ def callback(self, old_state: State, new_state: State, data: Any) -> None:
self.event.run(old_state, new_state, data)
def set_profile(self, profile_id: str) -> None:
@@ -275,7 +249,7 @@ class EduVPN(object):
return self.go_function(self.lib.InFSMState, state_id)
def get_saved_servers(self):
- servers, servers_err = self.go_function_custom_decode(
+ servers, servers_err = self.go_function(
self.lib.GetSavedServers,
decode_func=lambda lib, x: get_data_error(lib, x, get_servers),
)
@@ -284,3 +258,28 @@ class EduVPN(object):
raise servers_err
return servers
+
+eduvpn_objects: Dict[str, EduVPN] = {}
+
+
+@VPNStateChange
+def state_callback(name: bytes, old_state: int, new_state: int, data: Any):
+ name_decoded = name.decode()
+ if name_decoded not in eduvpn_objects:
+ return
+ eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data)
+
+
+
+def add_as_global_object(eduvpn: EduVPN) -> bool:
+ global eduvpn_objects
+ if eduvpn.name not in eduvpn_objects:
+ eduvpn_objects[eduvpn.name] = eduvpn
+ return True
+ return False
+
+
+def remove_as_global_object(eduvpn: EduVPN):
+ global eduvpn_objects
+ eduvpn_objects.pop(eduvpn.name, None)
+
diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py
index 71b6487..01b5204 100644
--- a/wrappers/python/eduvpn_common/server.py
+++ b/wrappers/python/eduvpn_common/server.py
@@ -1,10 +1,11 @@
+from typing import List, Optional, Type
from eduvpn_common.types import cServer, cServers, cServerLocations, cServerProfiles
-from ctypes import cast, POINTER
+from ctypes import c_void_p, cast, POINTER, CDLL
from datetime import datetime
class Profile:
- def __init__(self, identifier, display_name, default_gateway: bool):
+ def __init__(self, identifier: str, display_name: str, default_gateway: bool):
self.identifier = identifier
self.display_name = display_name
self.default_gateway = default_gateway
@@ -14,19 +15,19 @@ class Profile:
class Profiles:
- def __init__(self, profiles, current):
+ def __init__(self, profiles: List[Profile], current: int):
self.profiles = profiles
self.current_index = current
@property
- def current(self):
+ def current(self) -> Optional[Profile]:
if self.current_index < len(self.profiles):
return self.profiles[self.current_index]
return None
class Server:
- def __init__(self, url, display_name, profiles=None, expire_time=0):
+ def __init__(self, url: str, display_name: str, profiles: Optional[Profiles] = None, expire_time: int = 0):
self.url = url
self.display_name = display_name
self.profiles = profiles
@@ -36,29 +37,28 @@ class Server:
return self.display_name
@property
- def category(self):
+ def category(self) -> str:
return "Custom Server"
class InstituteServer(Server):
- def __init__(self, url, display_name, support_contact, profiles, expire_time):
+ def __init__(self, url: str, display_name: str, support_contact: List[str], profiles: Profiles, expire_time: int):
super().__init__(url, display_name, profiles, expire_time)
self.support_contact = support_contact
@property
- def category(self):
+ def category(self) -> str:
return "Institute Access Server"
-
class SecureInternetServer(Server):
def __init__(
self,
- org_id,
- display_name,
- support_contact,
- profiles,
- expire_time,
- country_code,
+ org_id: str,
+ display_name: str,
+ support_contact: List[str],
+ profiles: Profiles,
+ expire_time: int,
+ country_code: str,
):
super().__init__(org_id, display_name, profiles, expire_time)
self.org_id = org_id
@@ -66,11 +66,11 @@ class SecureInternetServer(Server):
self.country_code = country_code
@property
- def category(self):
+ def category(self) -> str:
return "Secure Internet Server"
-def get_type_for_str(type_str: str):
+def get_type_for_str(type_str: str) -> Type[Server]:
if type_str == "secure_internet":
return SecureInternetServer
if type_str == "custom_server":
@@ -78,14 +78,14 @@ def get_type_for_str(type_str: str):
return InstituteServer
-def get_profiles(ptr):
+def get_profiles(ptr) -> Optional[Profiles]:
if not ptr:
- return []
+ return None
profiles = []
_profiles = ptr.contents
current_profile = _profiles.current
if not _profiles.profiles:
- return []
+ return None
for i in range(_profiles.total_profiles):
if not _profiles.profiles[i]:
continue
@@ -100,7 +100,7 @@ def get_profiles(ptr):
return Profiles(profiles, current_profile)
-def get_server(ptr, _type=None):
+def get_server(ptr, _type=None) -> Optional[Server]:
if not ptr:
return None
@@ -116,6 +116,8 @@ def get_server(ptr, _type=None):
for i in range(current_server.total_support_contact):
support_contact.append(current_server.support_contact[i].decode("utf-8"))
profiles = get_profiles(current_server.profiles)
+ if profiles is None:
+ return None
if _type is SecureInternetServer:
return SecureInternetServer(
identifier,
@@ -136,19 +138,19 @@ def get_server(ptr, _type=None):
return Server(identifier, display_name, profiles, current_server.expire_time)
-def get_transition_server(lib, ptr):
+def get_transition_server(lib: CDLL, ptr: c_void_p) -> Optional[Server]:
server = get_server(cast(ptr, POINTER(cServer)))
lib.FreeServer(ptr)
return server
-def get_transition_profiles(lib, ptr):
+def get_transition_profiles(lib: CDLL, ptr: c_void_p) -> Optional[Profiles]:
profiles = get_profiles(cast(ptr, POINTER(cServerProfiles)))
lib.FreeProfiles(ptr)
return profiles
-def get_servers(lib, ptr):
+def get_servers(lib: CDLL, ptr: c_void_p) -> Optional[List[Server]]:
if ptr:
returned = []
servers = cast(ptr, POINTER(cServers)).contents
@@ -175,7 +177,7 @@ def get_servers(lib, ptr):
return None
-def get_locations(lib, ptr):
+def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]:
if ptr:
locations = cast(ptr, POINTER(cServerLocations)).contents
location_list = []
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
index c543989..82987c0 100644
--- a/wrappers/python/eduvpn_common/types.py
+++ b/wrappers/python/eduvpn_common/types.py
@@ -1,6 +1,6 @@
-from ctypes import *
+from ctypes import Structure, c_int, c_char_p, c_size_t, c_ulonglong, c_void_p, CFUNCTYPE, POINTER, CDLL, cast, pointer
from eduvpn_common.error import ErrorLevel, WrappedError
-from typing import List, Optional, Tuple
+from typing import Any, Callable, Iterator, List, Optional, Tuple
class cError(Structure):
@@ -105,7 +105,7 @@ class ConfigError(Structure):
VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
-def encode_args(args, types):
+def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]:
for arg, t in zip(args, types):
# c_char_p needs the str to be encoded to bytes
if t is c_char_p:
@@ -113,17 +113,17 @@ def encode_args(args, types):
yield arg
-def decode_res(t):
+def decode_res(res: Any):
decode_map = {
c_int: get_bool,
c_void_p: get_error,
DataError: get_data_error,
ConfigError: get_config_error,
}
- return decode_map.get(t, lambda lib, x: x)
+ return decode_map.get(res, lambda lib, x: x)
-def get_ptr_string(lib, ptr: c_void_p) -> str:
+def get_ptr_string(lib: CDLL, ptr: c_void_p) -> str:
if ptr:
string = cast(ptr, c_char_p).value
lib.FreeString(ptr)
@@ -133,17 +133,17 @@ def get_ptr_string(lib, ptr: c_void_p) -> str:
def get_ptr_list_strings(
- lib, strings: POINTER(c_char_p), total_strings: c_size_t
+ lib: CDLL, strings: pointer, total_strings: int
) -> List[str]:
if strings:
strings_list = []
- for i in range(int(total_strings)):
+ for i in range(total_strings):
strings_list.append(strings[i].decode("utf-8"))
return strings_list
return []
-def get_error(lib, ptr: c_void_p) -> Optional[WrappedError]:
+def get_error(lib: CDLL, ptr: c_void_p) -> Optional[WrappedError]:
if not ptr:
return None
err = cast(ptr, POINTER(cError)).contents
@@ -155,7 +155,7 @@ def get_error(lib, ptr: c_void_p) -> Optional[WrappedError]:
def get_config_error(
- lib, config_error: ConfigError
+ lib: CDLL, config_error: ConfigError
) -> Tuple[str, str, Optional[WrappedError]]:
config = get_ptr_string(lib, config_error.config)
config_type = get_ptr_string(lib, config_error.config_type)
@@ -164,12 +164,12 @@ def get_config_error(
def get_data_error(
- lib, data_error: DataError, data_conv=get_ptr_string
+ lib: CDLL, data_error: DataError, data_conv: Callable = get_ptr_string
) -> Tuple[str, Optional[WrappedError]]:
data = data_conv(lib, data_error.data)
error = get_error(lib, data_error.error)
return data, error
-def get_bool(lib, boolInt: c_int) -> bool:
+def get_bool(lib: CDLL, boolInt: c_int) -> bool:
return boolInt == 1