summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-26 17:36:30 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-27 10:53:37 +0200
commit09ec69dfdef409868f1cb39cb8cc4b33c8690c9f (patch)
tree109925dbbee4a9120211897582760f96010ae8f2
parent0a19c2dedcaaa177b420eac99149515d84508204 (diff)
Python: Reformat and move most loading out of __init__
-rw-r--r--wrappers/python/eduvpn_common/__init__.py270
-rw-r--r--wrappers/python/eduvpn_common/discovery.py18
-rw-r--r--wrappers/python/eduvpn_common/error.py3
-rw-r--r--wrappers/python/eduvpn_common/event.py19
-rw-r--r--wrappers/python/eduvpn_common/loader.py113
-rw-r--r--wrappers/python/eduvpn_common/main.py73
-rw-r--r--wrappers/python/eduvpn_common/server.py12
-rw-r--r--wrappers/python/eduvpn_common/types.py175
-rw-r--r--wrappers/python/main.py8
9 files changed, 362 insertions, 329 deletions
diff --git a/wrappers/python/eduvpn_common/__init__.py b/wrappers/python/eduvpn_common/__init__.py
index 1406fa2..e69de29 100644
--- a/wrappers/python/eduvpn_common/__init__.py
+++ b/wrappers/python/eduvpn_common/__init__.py
@@ -1,270 +0,0 @@
-from ctypes import *
-from collections import defaultdict
-import pathlib
-import platform
-from typing import Tuple, Optional
-from typing import List
-from eduvpn_common.error import WrappedError, ErrorLevel
-
-_lib_prefixes = defaultdict(
- lambda: "lib",
- {
- "windows": "",
- },
-)
-
-_lib_suffixes = defaultdict(
- lambda: ".so",
- {
- "windows": ".dll",
- "darwin": ".dylib",
- },
-)
-
-_os = platform.system().lower()
-
-_libname = "eduvpn_common"
-_libfile = f"{_lib_prefixes[_os]}{_libname}{_lib_suffixes[_os]}"
-
-lib = None
-
-# Try to load in the normal path
-try:
- lib = cdll.LoadLibrary(_libfile)
-# Otherwise, library should have been copied to the lib/ folder
-except:
- lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile))
-
-
-class cError(Structure):
- _fields_ = [
- ("level", c_int),
- ("traceback", c_char_p),
- ("cause", c_char_p),
- ]
-
-class cServerLocations(Structure):
- _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)]
-
-
-class cDiscoveryOrganization(Structure):
- _fields_ = [
- ("display_name", c_char_p),
- ("org_id", c_char_p),
- ("secure_internet_home", c_char_p),
- ("keyword_list", c_char_p),
- ]
-
-
-class cDiscoveryOrganizations(Structure):
- _fields_ = [
- ("version", c_ulonglong),
- ("organizations", POINTER(POINTER(cDiscoveryOrganization))),
- ("total_organizations", c_size_t),
- ]
-
-
-class cDiscoveryServer(Structure):
- _fields_ = [
- ("authentication_url_template", c_char_p),
- ("base_url", c_char_p),
- ("country_code", c_char_p),
- ("display_name", c_char_p),
- ("keyword_list", c_char_p),
- ("public_key_list", POINTER(c_char_p)),
- ("total_public_keys", c_size_t),
- ("server_type", c_char_p),
- ("support_contact", POINTER(c_char_p)),
- ("total_support_contact", c_size_t),
- ]
-
-
-class cDiscoveryServers(Structure):
- _fields_ = [
- ("version", c_ulonglong),
- ("servers", POINTER(POINTER(cDiscoveryServer))),
- ("total_servers", c_size_t),
- ]
-
-
-class cServerProfile(Structure):
- _fields_ = [
- ("identifier", c_char_p),
- ("display_name", c_char_p),
- ("default_gateway", c_int),
- ]
-
-
-class cServerProfiles(Structure):
- _fields_ = [
- ("current", c_int),
- ("profiles", POINTER(POINTER(cServerProfile))),
- ("total_profiles", c_size_t),
- ]
-
-
-class cServer(Structure):
- _fields_ = [
- ("identifier", c_char_p),
- ("display_name", c_char_p),
- ("server_type", c_char_p),
- ("country_code", c_char_p),
- ("support_contact", POINTER(c_char_p)),
- ("total_support_contact", c_size_t),
- ("profiles", POINTER(cServerProfiles)),
- ("expire_time", c_ulonglong),
- ]
-
-
-class cServers(Structure):
- _fields_ = [
- ("custom_servers", POINTER(POINTER(cServer))),
- ("total_custom", c_size_t),
- ("institute_servers", POINTER(POINTER(cServer))),
- ("total_institute", c_size_t),
- ("secure_internet", POINTER(cServer)),
- ]
-
-
-class DataError(Structure):
- _fields_ = [("data", c_void_p), ("error", c_void_p)]
-
-
-class ConfigError(Structure):
- _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)]
-
-
-VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
-
-# Exposed functions
-# We have to use c_void_p instead of c_char_p to free it properly
-# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
-lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [
- c_char_p
-], c_void_p
-lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [
- c_char_p,
- c_char_p,
-], c_void_p
-lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [
- c_char_p,
- c_char_p,
-], c_void_p
-lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
- c_char_p,
- c_char_p,
- c_int,
-], ConfigError
-lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [
- c_char_p,
- c_char_p,
- c_int,
-], ConfigError
-lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [
- c_char_p,
- c_char_p,
- c_int,
-], ConfigError
-lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None
-lib.Register.argtypes, lib.Register.restype = [
- c_char_p,
- c_char_p,
- VPNStateChange,
- c_int,
-], c_void_p
-lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
- c_char_p
-], DataError
-lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
-lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
-lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
-lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
-lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [
- c_char_p
-], c_void_p
-lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [
- c_char_p,
- c_char_p,
-], c_void_p
-lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p
-lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p
-lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p
-lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c_void_p
-lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
-lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
-lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
-lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None
-lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
-lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
-lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
- c_void_p
-], None
-lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None
-lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None
-lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None
-lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
-lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
-lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError
-
-
-def encode_args(args, types):
- for arg, t in zip(args, types):
- # c_char_p needs the str to be encoded to bytes
- if t is c_char_p:
- arg = arg.encode("utf-8")
- yield arg
-
-
-def decode_res(t):
- return decode_map.get(t, lambda x: x)
-
-
-def get_ptr_string(ptr: c_void_p) -> str:
- if ptr:
- string = cast(ptr, c_char_p).value
- lib.FreeString(ptr)
- if string:
- return string.decode("utf-8")
- return ""
-
-
-def get_ptr_list_strings(
- strings: POINTER(c_char_p), total_strings: c_size_t
-) -> List[str]:
- if strings:
- strings_list = []
- for i in range(total_strings):
- strings_list.append(strings[i].decode("utf-8"))
- return strings_list
- return []
-
-def get_error(ptr: c_void_p) -> Optional[WrappedError]:
- if not ptr:
- return None
- err = cast(ptr, POINTER(cError)).contents
- wrapped = WrappedError(err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level))
- lib.FreeError(ptr)
- return wrapped
-
-def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]:
- config = get_ptr_string(config_error.config)
- config_type = get_ptr_string(config_error.config_type)
- err = get_error(config_error.error)
- return config, config_type, err
-
-def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]:
- data = data_conv(data_error.data)
- error = get_error(data_error.error)
- return data, error
-
-
-def get_bool(boolInt: c_int) -> bool:
- return boolInt == 1
-
-
-decode_map = {
- c_int: get_bool,
- c_void_p: get_error,
- DataError: get_data_error,
- ConfigError: get_config_error,
-}
diff --git a/wrappers/python/eduvpn_common/discovery.py b/wrappers/python/eduvpn_common/discovery.py
index 68741bc..3a1cbe6 100644
--- a/wrappers/python/eduvpn_common/discovery.py
+++ b/wrappers/python/eduvpn_common/discovery.py
@@ -1,4 +1,8 @@
-from eduvpn_common import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings
+from eduvpn_common.types import (
+ cDiscoveryOrganizations,
+ cDiscoveryServers,
+ get_ptr_list_strings,
+)
from ctypes import cast, POINTER
@@ -62,7 +66,7 @@ def get_disco_organization(ptr):
return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
-def get_disco_server(ptr):
+def get_disco_server(lib, ptr):
if not ptr:
return None
@@ -75,11 +79,11 @@ def get_disco_server(ptr):
display_name = current_server.display_name.decode("utf-8")
keyword_list = current_server.keyword_list.decode("utf-8")
public_keys = get_ptr_list_strings(
- current_server.public_key_list, current_server.total_public_keys
+ lib, current_server.public_key_list, current_server.total_public_keys
)
server_type = current_server.server_type.decode("utf-8")
support_contacts = get_ptr_list_strings(
- current_server.support_contact, current_server.total_support_contact
+ lib, current_server.support_contact, current_server.total_support_contact
)
return DiscoServer(
authentication_url_template,
@@ -93,7 +97,7 @@ def get_disco_server(ptr):
)
-def get_disco_servers(ptr):
+def get_disco_servers(lib, ptr):
if ptr:
svrs = cast(ptr, POINTER(cDiscoveryServers)).contents
@@ -101,7 +105,7 @@ def get_disco_servers(ptr):
if svrs.servers:
for i in range(svrs.total_servers):
- current = get_disco_server(svrs.servers[i])
+ current = get_disco_server(lib, svrs.servers[i])
if current is None:
continue
@@ -111,7 +115,7 @@ def get_disco_servers(ptr):
return None
-def get_disco_organizations(ptr):
+def get_disco_organizations(lib, ptr):
if ptr:
orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
organizations = []
diff --git a/wrappers/python/eduvpn_common/error.py b/wrappers/python/eduvpn_common/error.py
index 50298bb..a5b59b4 100644
--- a/wrappers/python/eduvpn_common/error.py
+++ b/wrappers/python/eduvpn_common/error.py
@@ -1,15 +1,16 @@
from enum import Enum
+
class ErrorLevel(Enum):
ERR_OTHER = 0
ERR_INFO = 1
ERR_WARNING = 2
ERR_FATAL = 3
+
class WrappedError(Exception):
def __init__(self, traceback: str, cause: str, level: ErrorLevel):
super(WrappedError, self).__init__(cause)
self.traceback = traceback
self.cause = cause
self.level = level
-
diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py
index 4532bef..a47a0a7 100644
--- a/wrappers/python/eduvpn_common/event.py
+++ b/wrappers/python/eduvpn_common/event.py
@@ -1,4 +1,3 @@
-from eduvpn_common import VPNStateChange, get_ptr_string
from enum import Enum
from typing import Callable
from eduvpn_common.state import State, StateType
@@ -8,6 +7,7 @@ from eduvpn_common.server import (
get_transition_server,
get_servers,
)
+from eduvpn_common.types import get_ptr_string
EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
@@ -22,29 +22,30 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
-def convert_data(state: State, data):
+def convert_data(lib, state: State, data):
if not data:
return None
if state is State.NO_SERVER:
- return get_servers(data)
+ return get_servers(lib, data)
if state is State.OAUTH_STARTED:
- return get_ptr_string(data)
+ return get_ptr_string(lib, data)
if state is State.ASK_LOCATION:
- return get_locations(data)
+ return get_locations(lib, data)
if state is State.ASK_PROFILE:
- return get_transition_profiles(data)
+ return get_transition_profiles(lib, data)
if state in [
State.DISCONNECTED,
State.DISCONNECTING,
State.CONNECTING,
State.CONNECTED,
]:
- return get_transition_server(data)
+ return get_transition_server(lib, data)
class EventHandler(object):
- def __init__(self):
+ def __init__(self, lib):
self.handlers = {}
+ self.lib = lib
def change_class_callbacks(self, cls, add=True) -> None:
# Loop over method names
@@ -103,7 +104,7 @@ class EventHandler(object):
# The state is done when the wait event finishes
converted = data
if convert:
- converted = convert_data(new_state, data)
+ converted = convert_data(self.lib, new_state, data)
self.run_state(old_state, new_state, StateType.Leave, converted)
self.run_state(new_state, old_state, StateType.Enter, converted)
self.run_state(new_state, old_state, StateType.Wait, converted)
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
new file mode 100644
index 0000000..bce2638
--- /dev/null
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -0,0 +1,113 @@
+from ctypes import *
+from collections import defaultdict
+import pathlib
+import platform
+from eduvpn_common.types import *
+
+
+def load_lib(version: str):
+ lib_prefixes = defaultdict(
+ lambda: "lib",
+ {
+ "windows": "",
+ },
+ )
+
+ lib_suffixes = defaultdict(
+ lambda: ".so",
+ {
+ "windows": ".dll",
+ "darwin": ".dylib",
+ },
+ )
+
+ os = platform.system().lower()
+
+ libname = "eduvpn_common"
+ libfile = f"{lib_prefixes[os]}{libname}{lib_suffixes[os]}"
+
+ lib = None
+
+ # Try to load in the normal path
+ try:
+ lib = cdll.LoadLibrary(libfile)
+ # Otherwise, library should have been copied to the lib/ folder
+ except:
+ lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / libfile))
+
+ return lib
+
+
+def initialize_functions(lib):
+ # Exposed functions
+ # We have to use c_void_p instead of c_char_p to free it properly
+ # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
+ lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
+ lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [
+ c_char_p
+ ], c_void_p
+ lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None
+ lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
+ c_void_p
+ ], None
+ lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None
+ lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None
+ lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None
+ lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
+ lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None
+ lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
+ lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
+ lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+ ], ConfigError
+ lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+ ], ConfigError
+ lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+ ], ConfigError
+ lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
+ c_char_p
+ ], DataError
+ lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
+ lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError
+ lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
+ lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
+ lib.Register.argtypes, lib.Register.restype = [
+ c_char_p,
+ c_char_p,
+ VPNStateChange,
+ c_int,
+ ], c_void_p
+ lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [
+ c_char_p,
+ c_char_p,
+ ], c_void_p
+ lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+ ], c_void_p
+ lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [
+ c_char_p
+ ], c_void_p
+ lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
+ lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p
+ lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p
+ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [
+ c_char_p,
+ c_int,
+ ], c_void_p
+ lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p
+ lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
+ lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
+ lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [
+ c_char_p,
+ c_char_p,
+ ], c_void_p
+ lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index 3875ad9..1b18fb0 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,13 +1,16 @@
-from eduvpn_common import lib, VPNStateChange, encode_args, decode_res, get_data_error
from typing import Optional, Tuple
import threading
from eduvpn_common.discovery import get_disco_organizations, get_disco_servers
from eduvpn_common.event import EventHandler
-from eduvpn_common.state import State, StateType
+from eduvpn_common.loader import initialize_functions, load_lib
+from eduvpn_common.types import VPNStateChange, encode_args, decode_res, get_data_error
from eduvpn_common.server import get_servers
+from eduvpn_common.state import State, StateType
eduvpn_objects = {}
+VERSION = "0.1.0"
+
def add_as_global_object(eduvpn) -> bool:
global eduvpn_objects
@@ -32,10 +35,15 @@ def state_callback(name, old_state, new_state, data):
class EduVPN(object):
def __init__(self, name: str, config_directory: str):
- self.event_handler = EventHandler()
self.name = name
self.config_directory = config_directory
+ # Load the library
+ self.lib = load_lib(VERSION)
+ initialize_functions(self.lib)
+
+ self.event_handler = EventHandler(self.lib)
+
# Callbacks that need to wait for specific events
# The ask profile callback needs to wait for the UI thread to select a profile
@@ -57,22 +65,22 @@ class EduVPN(object):
# The functions all have at least one arg type which is the name of the client
args_gen = encode_args(list(args), func.argtypes[1:])
res = func(self.name.encode("utf-8"), *(args_gen))
- return decode_res(func.restype)(res)
+ return decode_res(func.restype)(self.lib, res)
def go_function_custom_decode(self, func, decode_func, *args):
# The functions all have at least one arg type which is the name of the client
args_gen = encode_args(list(args), func.argtypes[1:])
res = func(self.name.encode("utf-8"), *(args_gen))
- return decode_func(res)
+ return decode_func(self.lib, res)
def cancel_oauth(self) -> None:
- cancel_oauth_err = self.go_function(lib.CancelOAuth)
+ cancel_oauth_err = self.go_function(self.lib.CancelOAuth)
if cancel_oauth_err:
raise cancel_oauth_err
def deregister(self) -> None:
- self.go_function(lib.Deregister)
+ self.go_function(self.lib.Deregister)
remove_as_global_object(self)
def register(self, debug: bool = False) -> None:
@@ -80,7 +88,7 @@ class EduVPN(object):
raise Exception("Already registered")
register_err = self.go_function(
- lib.Register, self.config_directory, state_callback, debug
+ self.lib.Register, self.config_directory, state_callback, debug
)
if register_err:
@@ -88,38 +96,40 @@ class EduVPN(object):
def get_disco_servers(self) -> str:
servers, servers_err = self.go_function_custom_decode(
- lib.GetDiscoServers, decode_func=lambda x: get_data_error(x, get_disco_servers)
+ self.lib.GetDiscoServers,
+ decode_func=lambda lib, x: get_data_error(lib, x, get_disco_servers),
)
if servers_err:
- raise servers_err
+ raise servers_err
return servers
def get_disco_organizations(self) -> str:
organizations, organizations_err = self.go_function_custom_decode(
- lib.GetDiscoOrganizations, decode_func=lambda x: get_data_error(x, get_disco_organizations)
+ self.lib.GetDiscoOrganizations,
+ decode_func=lambda lib, x: get_data_error(lib, x, get_disco_organizations),
)
if organizations_err:
- raise organizations_err
+ raise organizations_err
return organizations
def remove_secure_internet(self):
- remove_err = self.go_function(lib.RemoveSecureInternet)
+ remove_err = self.go_function(self.lib.RemoveSecureInternet)
if remove_err:
raise remove_err
def remove_institute_access(self, url: str):
- remove_err = self.go_function(lib.RemoveInstituteAccess, url)
+ remove_err = self.go_function(self.lib.RemoveInstituteAccess, url)
if remove_err:
raise remove_err
def remove_custom_server(self, url: str):
- remove_err = self.go_function(lib.RemoveCustomServer, url)
+ remove_err = self.go_function(self.lib.RemoveCustomServer, url)
if remove_err:
raise remove_err
@@ -143,49 +153,49 @@ class EduVPN(object):
def get_config_custom_server(
self, url: str, force_tcp: bool = False
) -> Tuple[str, str]:
- return self.get_config(url, lib.GetConfigCustomServer, force_tcp)
+ return self.get_config(url, self.lib.GetConfigCustomServer, force_tcp)
def get_config_institute_access(
self, url: str, force_tcp: bool = False
) -> Tuple[str, str]:
- return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp)
+ return self.get_config(url, self.lib.GetConfigInstituteAccess, force_tcp)
def get_config_secure_internet(
self, url: str, force_tcp: bool = False
) -> Tuple[str, str]:
self.location_event = threading.Event()
- return self.get_config(url, lib.GetConfigSecureInternet, force_tcp)
+ return self.get_config(url, self.lib.GetConfigSecureInternet, force_tcp)
def go_back(self) -> None:
# Ignore the error
- self.go_function(lib.GoBack)
+ self.go_function(self.lib.GoBack)
def set_connected(self) -> None:
- connect_err = self.go_function(lib.SetConnected)
+ connect_err = self.go_function(self.lib.SetConnected)
if connect_err:
raise connect_err
def set_disconnecting(self) -> None:
- disconnecting_err = self.go_function(lib.SetDisconnecting)
+ disconnecting_err = self.go_function(self.lib.SetDisconnecting)
if disconnecting_err:
raise disconnecting_err
def set_connecting(self) -> None:
- connecting_err = self.go_function(lib.SetConnecting)
+ connecting_err = self.go_function(self.lib.SetConnecting)
if connecting_err:
raise connecting_err
def set_disconnected(self, cleanup=True) -> None:
- disconnect_err = self.go_function(lib.SetDisconnected, cleanup)
+ disconnect_err = self.go_function(self.lib.SetDisconnected, cleanup)
if disconnect_err:
raise disconnect_err
def set_search_server(self) -> None:
- search_err = self.go_function(lib.SetSearchServer)
+ search_err = self.go_function(self.lib.SetSearchServer)
if search_err:
raise search_err
@@ -205,7 +215,7 @@ class EduVPN(object):
def set_profile(self, profile_id: str) -> None:
# Set the profile id
- profile_err = self.go_function(lib.SetProfileID, profile_id)
+ profile_err = self.go_function(self.lib.SetProfileID, profile_id)
# If there is a profile event, set it so that the wait callback finishes
# And so that the Go code can move to the next state
@@ -218,14 +228,14 @@ class EduVPN(object):
def change_secure_location(self) -> None:
# Set the location by country code
self.location_event = threading.Event()
- location_err = self.go_function(lib.ChangeSecureLocation)
+ location_err = self.go_function(self.lib.ChangeSecureLocation)
if location_err:
raise location_err
def set_secure_location(self, country_code: str) -> None:
# Set the location by country code
- location_err = self.go_function(lib.SetSecureLocation, country_code)
+ location_err = self.go_function(self.lib.SetSecureLocation, country_code)
# If there is a location event, set it so that the wait callback finishes
# And so that the Go code can move to the next state
@@ -236,18 +246,19 @@ class EduVPN(object):
raise location_err
def renew_session(self) -> None:
- renew_err = self.go_function(lib.RenewSession)
+ renew_err = self.go_function(self.lib.RenewSession)
if renew_err:
raise renew_err
def should_renew_button(self) -> bool:
- return self.go_function(lib.ShouldRenewButton)
+ return self.go_function(self.lib.ShouldRenewButton)
def in_fsm_state(self, state_id: State) -> bool:
- return self.go_function(lib.InFSMState, state_id)
+ return self.go_function(self.lib.InFSMState, state_id)
def get_saved_servers(self) -> str:
return self.go_function_custom_decode(
- lib.GetSavedServers, decode_func=lambda x: get_data_error(x, get_servers)
+ self.lib.GetSavedServers,
+ decode_func=lambda lib, x: get_data_error(lib, x, get_servers),
)
diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py
index 470f704..36c3643 100644
--- a/wrappers/python/eduvpn_common/server.py
+++ b/wrappers/python/eduvpn_common/server.py
@@ -1,5 +1,5 @@
-from eduvpn_common import lib, cServer, cServers, cServerLocations, cServerProfiles
-from ctypes import cast, POINTER, c_char_p
+from eduvpn_common.types import cServer, cServers, cServerLocations, cServerProfiles
+from ctypes import cast, POINTER
from datetime import datetime
@@ -125,19 +125,19 @@ def get_server(ptr, _type=None):
return Server(identifier, display_name, profiles, current_server.expire_time)
-def get_transition_server(ptr):
+def get_transition_server(lib, ptr):
server = get_server(cast(ptr, POINTER(cServer)))
lib.FreeServer(ptr)
return server
-def get_transition_profiles(ptr):
+def get_transition_profiles(lib, ptr):
profiles = get_profiles(cast(ptr, POINTER(cServerProfiles)))
lib.FreeProfiles(ptr)
return profiles
-def get_servers(ptr):
+def get_servers(lib, ptr):
if ptr:
returned = []
servers = cast(ptr, POINTER(cServers)).contents
@@ -164,7 +164,7 @@ def get_servers(ptr):
return None
-def get_locations(ptr):
+def get_locations(lib, ptr):
if ptr:
locations = cast(ptr, POINTER(cServerLocations)).contents
location_list = []
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
new file mode 100644
index 0000000..c543989
--- /dev/null
+++ b/wrappers/python/eduvpn_common/types.py
@@ -0,0 +1,175 @@
+from ctypes import *
+from eduvpn_common.error import ErrorLevel, WrappedError
+from typing import List, Optional, Tuple
+
+
+class cError(Structure):
+ _fields_ = [
+ ("level", c_int),
+ ("traceback", c_char_p),
+ ("cause", c_char_p),
+ ]
+
+
+class cServerLocations(Structure):
+ _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)]
+
+
+class cDiscoveryOrganization(Structure):
+ _fields_ = [
+ ("display_name", c_char_p),
+ ("org_id", c_char_p),
+ ("secure_internet_home", c_char_p),
+ ("keyword_list", c_char_p),
+ ]
+
+
+class cDiscoveryOrganizations(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("organizations", POINTER(POINTER(cDiscoveryOrganization))),
+ ("total_organizations", c_size_t),
+ ]
+
+
+class cDiscoveryServer(Structure):
+ _fields_ = [
+ ("authentication_url_template", c_char_p),
+ ("base_url", c_char_p),
+ ("country_code", c_char_p),
+ ("display_name", c_char_p),
+ ("keyword_list", c_char_p),
+ ("public_key_list", POINTER(c_char_p)),
+ ("total_public_keys", c_size_t),
+ ("server_type", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ]
+
+
+class cDiscoveryServers(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("servers", POINTER(POINTER(cDiscoveryServer))),
+ ("total_servers", c_size_t),
+ ]
+
+
+class cServerProfile(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("default_gateway", c_int),
+ ]
+
+
+class cServerProfiles(Structure):
+ _fields_ = [
+ ("current", c_int),
+ ("profiles", POINTER(POINTER(cServerProfile))),
+ ("total_profiles", c_size_t),
+ ]
+
+
+class cServer(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("server_type", c_char_p),
+ ("country_code", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ("profiles", POINTER(cServerProfiles)),
+ ("expire_time", c_ulonglong),
+ ]
+
+
+class cServers(Structure):
+ _fields_ = [
+ ("custom_servers", POINTER(POINTER(cServer))),
+ ("total_custom", c_size_t),
+ ("institute_servers", POINTER(POINTER(cServer))),
+ ("total_institute", c_size_t),
+ ("secure_internet", POINTER(cServer)),
+ ]
+
+
+class DataError(Structure):
+ _fields_ = [("data", c_void_p), ("error", c_void_p)]
+
+
+class ConfigError(Structure):
+ _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)]
+
+
+VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
+
+
+def encode_args(args, types):
+ for arg, t in zip(args, types):
+ # c_char_p needs the str to be encoded to bytes
+ if t is c_char_p:
+ arg = arg.encode("utf-8")
+ yield arg
+
+
+def decode_res(t):
+ decode_map = {
+ c_int: get_bool,
+ c_void_p: get_error,
+ DataError: get_data_error,
+ ConfigError: get_config_error,
+ }
+ return decode_map.get(t, lambda lib, x: x)
+
+
+def get_ptr_string(lib, ptr: c_void_p) -> str:
+ if ptr:
+ string = cast(ptr, c_char_p).value
+ lib.FreeString(ptr)
+ if string:
+ return string.decode("utf-8")
+ return ""
+
+
+def get_ptr_list_strings(
+ lib, strings: POINTER(c_char_p), total_strings: c_size_t
+) -> List[str]:
+ if strings:
+ strings_list = []
+ for i in range(int(total_strings)):
+ strings_list.append(strings[i].decode("utf-8"))
+ return strings_list
+ return []
+
+
+def get_error(lib, ptr: c_void_p) -> Optional[WrappedError]:
+ if not ptr:
+ return None
+ err = cast(ptr, POINTER(cError)).contents
+ wrapped = WrappedError(
+ err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level)
+ )
+ lib.FreeError(ptr)
+ return wrapped
+
+
+def get_config_error(
+ lib, config_error: ConfigError
+) -> Tuple[str, str, Optional[WrappedError]]:
+ config = get_ptr_string(lib, config_error.config)
+ config_type = get_ptr_string(lib, config_error.config_type)
+ err = get_error(lib, config_error.error)
+ return config, config_type, err
+
+
+def get_data_error(
+ lib, data_error: DataError, data_conv=get_ptr_string
+) -> Tuple[str, Optional[WrappedError]]:
+ data = data_conv(lib, data_error.data)
+ error = get_error(lib, data_error.error)
+ return data, error
+
+
+def get_bool(lib, boolInt: c_int) -> bool:
+ return boolInt == 1
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index 0bd2502..657f0ab 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -1,9 +1,7 @@
import eduvpn_common.main as eduvpn
from eduvpn_common.state import State, StateType
import webbrowser
-import json
import sys
-import time
from typing import List
# Asks the user for a profile index
@@ -78,9 +76,9 @@ if __name__ == "__main__":
except Exception as e:
print("Failed registering:", e)
- #server = input(
- # "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
- #)
+ server = input(
+ "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
+ )
# Get a Wireguard/OpenVPN config
try: