summaryrefslogtreecommitdiff
path: root/wrappers/python/eduvpn_common
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-26 16:47:35 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-26 16:48:22 +0200
commit0a19c2dedcaaa177b420eac99149515d84508204 (patch)
tree34631498e694895a398da32fd077d855d938b113 /wrappers/python/eduvpn_common
parent060d133acbc1d11cd79e32c6861956c265d87c7f (diff)
Python: Move from src/ to eduvpn_common/ and absolufy imports
Diffstat (limited to 'wrappers/python/eduvpn_common')
-rw-r--r--wrappers/python/eduvpn_common/__init__.py270
-rw-r--r--wrappers/python/eduvpn_common/discovery.py126
-rw-r--r--wrappers/python/eduvpn_common/error.py15
-rw-r--r--wrappers/python/eduvpn_common/event.py109
-rw-r--r--wrappers/python/eduvpn_common/main.py253
-rw-r--r--wrappers/python/eduvpn_common/server.py175
-rw-r--r--wrappers/python/eduvpn_common/state.py24
7 files changed, 972 insertions, 0 deletions
diff --git a/wrappers/python/eduvpn_common/__init__.py b/wrappers/python/eduvpn_common/__init__.py
new file mode 100644
index 0000000..1406fa2
--- /dev/null
+++ b/wrappers/python/eduvpn_common/__init__.py
@@ -0,0 +1,270 @@
+from ctypes import *
+from collections import defaultdict
+import pathlib
+import platform
+from typing import Tuple, Optional
+from typing import List
+from eduvpn_common.error import WrappedError, ErrorLevel
+
+_lib_prefixes = defaultdict(
+ lambda: "lib",
+ {
+ "windows": "",
+ },
+)
+
+_lib_suffixes = defaultdict(
+ lambda: ".so",
+ {
+ "windows": ".dll",
+ "darwin": ".dylib",
+ },
+)
+
+_os = platform.system().lower()
+
+_libname = "eduvpn_common"
+_libfile = f"{_lib_prefixes[_os]}{_libname}{_lib_suffixes[_os]}"
+
+lib = None
+
+# Try to load in the normal path
+try:
+ lib = cdll.LoadLibrary(_libfile)
+# Otherwise, library should have been copied to the lib/ folder
+except:
+ lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile))
+
+
+class cError(Structure):
+ _fields_ = [
+ ("level", c_int),
+ ("traceback", c_char_p),
+ ("cause", c_char_p),
+ ]
+
+class cServerLocations(Structure):
+ _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)]
+
+
+class cDiscoveryOrganization(Structure):
+ _fields_ = [
+ ("display_name", c_char_p),
+ ("org_id", c_char_p),
+ ("secure_internet_home", c_char_p),
+ ("keyword_list", c_char_p),
+ ]
+
+
+class cDiscoveryOrganizations(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("organizations", POINTER(POINTER(cDiscoveryOrganization))),
+ ("total_organizations", c_size_t),
+ ]
+
+
+class cDiscoveryServer(Structure):
+ _fields_ = [
+ ("authentication_url_template", c_char_p),
+ ("base_url", c_char_p),
+ ("country_code", c_char_p),
+ ("display_name", c_char_p),
+ ("keyword_list", c_char_p),
+ ("public_key_list", POINTER(c_char_p)),
+ ("total_public_keys", c_size_t),
+ ("server_type", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ]
+
+
+class cDiscoveryServers(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("servers", POINTER(POINTER(cDiscoveryServer))),
+ ("total_servers", c_size_t),
+ ]
+
+
+class cServerProfile(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("default_gateway", c_int),
+ ]
+
+
+class cServerProfiles(Structure):
+ _fields_ = [
+ ("current", c_int),
+ ("profiles", POINTER(POINTER(cServerProfile))),
+ ("total_profiles", c_size_t),
+ ]
+
+
+class cServer(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("server_type", c_char_p),
+ ("country_code", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ("profiles", POINTER(cServerProfiles)),
+ ("expire_time", c_ulonglong),
+ ]
+
+
+class cServers(Structure):
+ _fields_ = [
+ ("custom_servers", POINTER(POINTER(cServer))),
+ ("total_custom", c_size_t),
+ ("institute_servers", POINTER(POINTER(cServer))),
+ ("total_institute", c_size_t),
+ ("secure_internet", POINTER(cServer)),
+ ]
+
+
+class DataError(Structure):
+ _fields_ = [("data", c_void_p), ("error", c_void_p)]
+
+
+class ConfigError(Structure):
+ _fields_ = [("config", c_void_p), ("config_type", c_void_p), ("error", c_void_p)]
+
+
+VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
+
+# Exposed functions
+# We have to use c_void_p instead of c_char_p to free it properly
+# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
+lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [
+ c_char_p
+], c_void_p
+lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
+lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
+lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+], ConfigError
+lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+], ConfigError
+lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+], ConfigError
+lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None
+lib.Register.argtypes, lib.Register.restype = [
+ c_char_p,
+ c_char_p,
+ VPNStateChange,
+ c_int,
+], c_void_p
+lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
+ c_char_p
+], DataError
+lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
+lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
+lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
+lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
+lib.ChangeSecureLocation.argtypes, lib.ChangeSecureLocation.restype = [
+ c_char_p
+], c_void_p
+lib.SetSecureLocation.argtypes, lib.SetSecureLocation.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
+lib.SetConnected.argtypes, lib.SetConnected.restype = [c_char_p], c_void_p
+lib.SetDisconnecting.argtypes, lib.SetDisconnecting.restype = [c_char_p], c_void_p
+lib.SetConnecting.argtypes, lib.SetConnecting.restype = [c_char_p], c_void_p
+lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c_void_p
+lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
+lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
+lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
+lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None
+lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
+lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
+lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
+ c_void_p
+], None
+lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None
+lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None
+lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None
+lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
+lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
+lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError
+
+
+def encode_args(args, types):
+ for arg, t in zip(args, types):
+ # c_char_p needs the str to be encoded to bytes
+ if t is c_char_p:
+ arg = arg.encode("utf-8")
+ yield arg
+
+
+def decode_res(t):
+ return decode_map.get(t, lambda x: x)
+
+
+def get_ptr_string(ptr: c_void_p) -> str:
+ if ptr:
+ string = cast(ptr, c_char_p).value
+ lib.FreeString(ptr)
+ if string:
+ return string.decode("utf-8")
+ return ""
+
+
+def get_ptr_list_strings(
+ strings: POINTER(c_char_p), total_strings: c_size_t
+) -> List[str]:
+ if strings:
+ strings_list = []
+ for i in range(total_strings):
+ strings_list.append(strings[i].decode("utf-8"))
+ return strings_list
+ return []
+
+def get_error(ptr: c_void_p) -> Optional[WrappedError]:
+ if not ptr:
+ return None
+ err = cast(ptr, POINTER(cError)).contents
+ wrapped = WrappedError(err.traceback.decode("utf-8"), err.cause.decode("utf-8"), ErrorLevel(err.level))
+ lib.FreeError(ptr)
+ return wrapped
+
+def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]:
+ config = get_ptr_string(config_error.config)
+ config_type = get_ptr_string(config_error.config_type)
+ err = get_error(config_error.error)
+ return config, config_type, err
+
+def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]:
+ data = data_conv(data_error.data)
+ error = get_error(data_error.error)
+ return data, error
+
+
+def get_bool(boolInt: c_int) -> bool:
+ return boolInt == 1
+
+
+decode_map = {
+ c_int: get_bool,
+ c_void_p: get_error,
+ DataError: get_data_error,
+ ConfigError: get_config_error,
+}
diff --git a/wrappers/python/eduvpn_common/discovery.py b/wrappers/python/eduvpn_common/discovery.py
new file mode 100644
index 0000000..68741bc
--- /dev/null
+++ b/wrappers/python/eduvpn_common/discovery.py
@@ -0,0 +1,126 @@
+from eduvpn_common import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings
+from ctypes import cast, POINTER
+
+
+class DiscoOrganization:
+ def __init__(self, display_name, org_id, secure_internet_home, keyword_list):
+ self.display_name = display_name
+ self.org_id = org_id
+ self.secure_internet_home = secure_internet_home
+ self.keyword_list = keyword_list
+
+ def __str__(self):
+ return self.display_name
+
+
+class DiscoOrganizations:
+ def __init__(self, version, organizations):
+ self.version = version
+ self.organizations = organizations
+
+
+class DiscoServer:
+ def __init__(
+ self,
+ authentication_url_template,
+ base_url,
+ country_code,
+ display_name,
+ keyword_list,
+ public_keys,
+ server_type,
+ support_contacts,
+ ):
+ self.authentication_url_template = authentication_url_template
+ self.base_url = base_url
+ self.country_code = country_code
+ self.display_name = display_name
+ self.keyword_list = keyword_list
+ self.public_keys = public_keys
+ self.server_type = server_type
+ self.support_contacts = support_contacts
+
+ def __str__(self):
+ return self.display_name
+
+
+class DiscoServers:
+ def __init__(self, version, servers):
+ self.version = version
+ self.servers = servers
+
+
+def get_disco_organization(ptr):
+ if not ptr:
+ return None
+
+ current_organization = ptr.contents
+ display_name = current_organization.display_name.decode("utf-8")
+ org_id = current_organization.org_id.decode("utf-8")
+ secure_internet_home = current_organization.secure_internet_home.decode("utf-8")
+ keyword_list = current_organization.keyword_list.decode("utf-8")
+ return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
+
+
+def get_disco_server(ptr):
+ if not ptr:
+ return None
+
+ current_server = ptr.contents
+ authentication_url_template = current_server.authentication_url_template.decode(
+ "utf-8"
+ )
+ base_url = current_server.base_url.decode("utf-8")
+ country_code = current_server.country_code.decode("utf-8")
+ display_name = current_server.display_name.decode("utf-8")
+ keyword_list = current_server.keyword_list.decode("utf-8")
+ public_keys = get_ptr_list_strings(
+ current_server.public_key_list, current_server.total_public_keys
+ )
+ server_type = current_server.server_type.decode("utf-8")
+ support_contacts = get_ptr_list_strings(
+ current_server.support_contact, current_server.total_support_contact
+ )
+ return DiscoServer(
+ authentication_url_template,
+ base_url,
+ country_code,
+ display_name,
+ keyword_list,
+ public_keys,
+ server_type,
+ support_contacts,
+ )
+
+
+def get_disco_servers(ptr):
+ if ptr:
+ svrs = cast(ptr, POINTER(cDiscoveryServers)).contents
+
+ servers = []
+
+ if svrs.servers:
+ for i in range(svrs.total_servers):
+ current = get_disco_server(svrs.servers[i])
+
+ if current is None:
+ continue
+ servers.append(current)
+ lib.FreeDiscoServers(ptr)
+ return DiscoServers(svrs.version, servers)
+ return None
+
+
+def get_disco_organizations(ptr):
+ if ptr:
+ orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
+ organizations = []
+ if orgs.organizations:
+ for i in range(orgs.total_organizations):
+ current = get_disco_organization(orgs.organizations[i])
+ if current is None:
+ continue
+ organizations.append(current)
+ lib.FreeDiscoOrganizations(ptr)
+ return DiscoOrganizations(orgs.version, organizations)
+ return None
diff --git a/wrappers/python/eduvpn_common/error.py b/wrappers/python/eduvpn_common/error.py
new file mode 100644
index 0000000..50298bb
--- /dev/null
+++ b/wrappers/python/eduvpn_common/error.py
@@ -0,0 +1,15 @@
+from enum import Enum
+
+class ErrorLevel(Enum):
+ ERR_OTHER = 0
+ ERR_INFO = 1
+ ERR_WARNING = 2
+ ERR_FATAL = 3
+
+class WrappedError(Exception):
+ def __init__(self, traceback: str, cause: str, level: ErrorLevel):
+ super(WrappedError, self).__init__(cause)
+ self.traceback = traceback
+ self.cause = cause
+ self.level = level
+
diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py
new file mode 100644
index 0000000..4532bef
--- /dev/null
+++ b/wrappers/python/eduvpn_common/event.py
@@ -0,0 +1,109 @@
+from eduvpn_common import VPNStateChange, get_ptr_string
+from enum import Enum
+from typing import Callable
+from eduvpn_common.state import State, StateType
+from eduvpn_common.server import (
+ get_locations,
+ get_transition_profiles,
+ get_transition_server,
+ get_servers,
+)
+
+
+EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
+
+# A state transition decorator for classes
+# To use this, make sure to register the class with `register_class_callbacks`
+def class_state_transition(state: int, state_type: StateType) -> Callable:
+ def wrapper(func):
+ setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type))
+ return func
+
+ return wrapper
+
+
+def convert_data(state: State, data):
+ if not data:
+ return None
+ if state is State.NO_SERVER:
+ return get_servers(data)
+ if state is State.OAUTH_STARTED:
+ return get_ptr_string(data)
+ if state is State.ASK_LOCATION:
+ return get_locations(data)
+ if state is State.ASK_PROFILE:
+ return get_transition_profiles(data)
+ if state in [
+ State.DISCONNECTED,
+ State.DISCONNECTING,
+ State.CONNECTING,
+ State.CONNECTED,
+ ]:
+ return get_transition_server(data)
+
+
+class EventHandler(object):
+ def __init__(self):
+ self.handlers = {}
+
+ def change_class_callbacks(self, cls, add=True) -> None:
+ # Loop over method names
+ for method_name in dir(cls):
+ try:
+ # Get the method
+ method = getattr(cls, method_name)
+ except:
+ # Unable to get a value, go to the next
+ continue
+
+ # If it has a callback defined, add it to the events
+ method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None)
+ if method_value:
+ state, state_type = method_value
+
+ if add:
+ self.add_event(state, state_type, method)
+ else:
+ self.remove_event(state, state_type, method)
+
+ def remove_event(self, state: int, state_type: StateType, func: Callable):
+ for key, values in self.handlers.copy().items():
+ if key == (state, state_type):
+ values.remove(func)
+ if not values:
+ del self.handlers[key]
+ else:
+ self.handlers[key] = values
+
+ def add_event(self, state: int, state_type: StateType, func: Callable):
+ if (state, state_type) not in self.handlers:
+ self.handlers[(state, state_type)] = []
+ self.handlers[(state, state_type)].append(func)
+
+ # A decorator for standalone functions
+ def on(self, state: int, state_type: StateType) -> Callable:
+ def wrapped_f(func):
+ self.add_event(state, state_type, func)
+ return func
+
+ return wrapped_f
+
+ def run_state(
+ self, state: int, other_state: int, state_type: StateType, data: str
+ ) -> None:
+ if (state, state_type) not in self.handlers:
+ return
+ for func in self.handlers[(state, state_type)]:
+ func(other_state, data)
+
+ def run(
+ self, old_state: int, new_state: int, data: str, convert: bool = True
+ ) -> None:
+ # First run leave transitions, then enter
+ # The state is done when the wait event finishes
+ converted = data
+ if convert:
+ converted = convert_data(new_state, data)
+ self.run_state(old_state, new_state, StateType.Leave, converted)
+ self.run_state(new_state, old_state, StateType.Enter, converted)
+ self.run_state(new_state, old_state, StateType.Wait, converted)
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
new file mode 100644
index 0000000..3875ad9
--- /dev/null
+++ b/wrappers/python/eduvpn_common/main.py
@@ -0,0 +1,253 @@
+from eduvpn_common import lib, VPNStateChange, encode_args, decode_res, get_data_error
+from typing import Optional, Tuple
+import threading
+from eduvpn_common.discovery import get_disco_organizations, get_disco_servers
+from eduvpn_common.event import EventHandler
+from eduvpn_common.state import State, StateType
+from eduvpn_common.server import get_servers
+
+eduvpn_objects = {}
+
+
+def add_as_global_object(eduvpn) -> bool:
+ global eduvpn_objects
+ if eduvpn.name not in eduvpn_objects:
+ eduvpn_objects[eduvpn.name] = eduvpn
+ return True
+ return False
+
+
+def remove_as_global_object(eduvpn):
+ global eduvpn_objects
+ eduvpn_objects.pop(eduvpn.name, None)
+
+
+@VPNStateChange
+def state_callback(name, old_state, new_state, data):
+ name = name.decode()
+ if name not in eduvpn_objects:
+ return
+ eduvpn_objects[name].callback(State(old_state), State(new_state), data)
+
+
+class EduVPN(object):
+ def __init__(self, name: str, config_directory: str):
+ self.event_handler = EventHandler()
+ self.name = name
+ self.config_directory = config_directory
+
+ # Callbacks that need to wait for specific events
+
+ # The ask profile callback needs to wait for the UI thread to select a profile
+ # This is stored in the profile_event
+ self.profile_event: Optional[threading.Event] = None
+ self.location_event: Optional[threading.Event] = None
+
+ @self.event.on(State.ASK_PROFILE, StateType.Wait)
+ def wait_profile_event(old_state: int, profiles: str):
+ if self.profile_event:
+ self.profile_event.wait()
+
+ @self.event.on(State.ASK_LOCATION, StateType.Wait)
+ def wait_location_event(old_state: int, locations: str):
+ if self.location_event:
+ self.location_event.wait()
+
+ def go_function(self, func, *args):
+ # The functions all have at least one arg type which is the name of the client
+ args_gen = encode_args(list(args), func.argtypes[1:])
+ res = func(self.name.encode("utf-8"), *(args_gen))
+ return decode_res(func.restype)(res)
+
+ def go_function_custom_decode(self, func, decode_func, *args):
+ # The functions all have at least one arg type which is the name of the client
+ args_gen = encode_args(list(args), func.argtypes[1:])
+ res = func(self.name.encode("utf-8"), *(args_gen))
+ return decode_func(res)
+
+ def cancel_oauth(self) -> None:
+ cancel_oauth_err = self.go_function(lib.CancelOAuth)
+
+ if cancel_oauth_err:
+ raise cancel_oauth_err
+
+ def deregister(self) -> None:
+ self.go_function(lib.Deregister)
+ remove_as_global_object(self)
+
+ def register(self, debug: bool = False) -> None:
+ if not add_as_global_object(self):
+ raise Exception("Already registered")
+
+ register_err = self.go_function(
+ lib.Register, self.config_directory, state_callback, debug
+ )
+
+ if register_err:
+ raise register_err
+
+ def get_disco_servers(self) -> str:
+ servers, servers_err = self.go_function_custom_decode(
+ lib.GetDiscoServers, decode_func=lambda x: get_data_error(x, get_disco_servers)
+ )
+
+ if servers_err:
+ raise servers_err
+
+ return servers
+
+ def get_disco_organizations(self) -> str:
+ organizations, organizations_err = self.go_function_custom_decode(
+ lib.GetDiscoOrganizations, decode_func=lambda x: get_data_error(x, get_disco_organizations)
+ )
+
+ if organizations_err:
+ raise organizations_err
+
+ return organizations
+
+ def remove_secure_internet(self):
+ remove_err = self.go_function(lib.RemoveSecureInternet)
+
+ if remove_err:
+ raise remove_err
+
+ def remove_institute_access(self, url: str):
+ remove_err = self.go_function(lib.RemoveInstituteAccess, url)
+
+ if remove_err:
+ raise remove_err
+
+ def remove_custom_server(self, url: str):
+ remove_err = self.go_function(lib.RemoveCustomServer, url)
+
+ if remove_err:
+ raise remove_err
+
+ def get_config(self, url: str, func: callable, force_tcp: bool = False):
+ # Because it could be the case that a profile callback is started, store a threading event
+ # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
+ # The event is set in self.set_profile
+ self.profile_event = threading.Event()
+
+ config, config_type, config_err = self.go_function(func, url, force_tcp)
+
+ self.profile_event = None
+ self.location_event = None
+
+ if config_err:
+ raise config_err
+
+ return config, config_type
+
+ def get_config_custom_server(
+ self, url: str, force_tcp: bool = False
+ ) -> Tuple[str, str]:
+ return self.get_config(url, lib.GetConfigCustomServer, force_tcp)
+
+ def get_config_institute_access(
+ self, url: str, force_tcp: bool = False
+ ) -> Tuple[str, str]:
+ return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp)
+
+ def get_config_secure_internet(
+ self, url: str, force_tcp: bool = False
+ ) -> Tuple[str, str]:
+ self.location_event = threading.Event()
+ return self.get_config(url, lib.GetConfigSecureInternet, force_tcp)
+
+ def go_back(self) -> None:
+ # Ignore the error
+ self.go_function(lib.GoBack)
+
+ def set_connected(self) -> None:
+ connect_err = self.go_function(lib.SetConnected)
+
+ if connect_err:
+ raise connect_err
+
+ def set_disconnecting(self) -> None:
+ disconnecting_err = self.go_function(lib.SetDisconnecting)
+
+ if disconnecting_err:
+ raise disconnecting_err
+
+ def set_connecting(self) -> None:
+ connecting_err = self.go_function(lib.SetConnecting)
+
+ if connecting_err:
+ raise connecting_err
+
+ def set_disconnected(self, cleanup=True) -> None:
+ disconnect_err = self.go_function(lib.SetDisconnected, cleanup)
+
+ if disconnect_err:
+ raise disconnect_err
+
+ def set_search_server(self) -> None:
+ search_err = self.go_function(lib.SetSearchServer)
+
+ if search_err:
+ raise search_err
+
+ def remove_class_callbacks(self, cls) -> None:
+ self.event_handler.change_class_callbacks(cls, add=False)
+
+ def register_class_callbacks(self, cls) -> None:
+ self.event_handler.change_class_callbacks(cls)
+
+ @property
+ def event(self) -> EventHandler:
+ return self.event_handler
+
+ def callback(self, old_state: State, new_state: State, data) -> None:
+ self.event.run(old_state, new_state, data)
+
+ def set_profile(self, profile_id: str) -> None:
+ # Set the profile id
+ profile_err = self.go_function(lib.SetProfileID, profile_id)
+
+ # If there is a profile event, set it so that the wait callback finishes
+ # And so that the Go code can move to the next state
+ if self.profile_event:
+ self.profile_event.set()
+
+ if profile_err:
+ raise profile_err
+
+ def change_secure_location(self) -> None:
+ # Set the location by country code
+ self.location_event = threading.Event()
+ location_err = self.go_function(lib.ChangeSecureLocation)
+
+ if location_err:
+ raise location_err
+
+ def set_secure_location(self, country_code: str) -> None:
+ # Set the location by country code
+ location_err = self.go_function(lib.SetSecureLocation, country_code)
+
+ # If there is a location event, set it so that the wait callback finishes
+ # And so that the Go code can move to the next state
+ if self.location_event:
+ self.location_event.set()
+
+ if location_err:
+ raise location_err
+
+ def renew_session(self) -> None:
+ renew_err = self.go_function(lib.RenewSession)
+
+ if renew_err:
+ raise renew_err
+
+ def should_renew_button(self) -> bool:
+ return self.go_function(lib.ShouldRenewButton)
+
+ def in_fsm_state(self, state_id: State) -> bool:
+ return self.go_function(lib.InFSMState, state_id)
+
+ def get_saved_servers(self) -> str:
+ return self.go_function_custom_decode(
+ lib.GetSavedServers, decode_func=lambda x: get_data_error(x, get_servers)
+ )
diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py
new file mode 100644
index 0000000..470f704
--- /dev/null
+++ b/wrappers/python/eduvpn_common/server.py
@@ -0,0 +1,175 @@
+from eduvpn_common import lib, cServer, cServers, cServerLocations, cServerProfiles
+from ctypes import cast, POINTER, c_char_p
+from datetime import datetime
+
+
+class Profile:
+ def __init__(self, identifier, display_name, default_gateway: bool):
+ self.identifier = identifier
+ self.display_name = display_name
+ self.default_gateway = default_gateway
+
+ def __str__(self):
+ return self.display_name
+
+
+class Profiles:
+ def __init__(self, profiles, current):
+ self.profiles = profiles
+ self.current_index = current
+
+ @property
+ def current(self):
+ if self.current_index < len(self.profiles):
+ return self.profiles[self.current_index]
+ return None
+
+
+class Server:
+ def __init__(self, url, display_name, profiles=None, expire_time=0):
+ self.url = url
+ self.display_name = display_name
+ self.profiles = profiles
+ self.current_profile = None
+ self.expire_time = datetime.fromtimestamp(expire_time)
+
+ def __str__(self):
+ return self.display_name
+
+
+class InstituteServer(Server):
+ def __init__(self, url, display_name, support_contact, profiles, expire_time):
+ super().__init__(url, display_name, profiles, expire_time)
+ self.support_contact = support_contact
+
+
+class SecureInternetServer(Server):
+ def __init__(
+ self,
+ org_id,
+ display_name,
+ support_contact,
+ profiles,
+ expire_time,
+ country_code,
+ ):
+ super().__init__(org_id, display_name, profiles, expire_time)
+ self.org_id = org_id
+ self.support_contact = support_contact
+ self.country_code = country_code
+
+
+def get_type_for_str(type_str: str):
+ if type_str == "secure_internet":
+ return SecureInternetServer
+ if type_str == "custom_server":
+ return Server
+ return InstituteServer
+
+
+def get_profiles(ptr):
+ if not ptr:
+ return []
+ profiles = []
+ _profiles = ptr.contents
+ current_profile = _profiles.current
+ if not _profiles.profiles:
+ return []
+ for i in range(_profiles.total_profiles):
+ if not _profiles.profiles[i]:
+ continue
+ profile = _profiles.profiles[i].contents
+ profiles.append(
+ Profile(
+ profile.identifier.decode("utf-8"),
+ profile.display_name.decode("utf-8"),
+ profile.default_gateway == 1,
+ )
+ )
+ return Profiles(profiles, current_profile)
+
+
+def get_server(ptr, _type=None):
+ if not ptr:
+ return None
+
+ current_server = ptr.contents
+ if _type is None:
+ _type = get_type_for_str(current_server.server_type.decode("utf-8"))
+
+ identifier = current_server.identifier.decode("utf-8")
+ display_name = current_server.display_name.decode("utf-8")
+
+ if _type is not Server:
+ support_contact = []
+ for i in range(current_server.total_support_contact):
+ support_contact.append(current_server.support_contact[i].decode("utf-8"))
+ profiles = get_profiles(current_server.profiles)
+ if _type is SecureInternetServer:
+ return SecureInternetServer(
+ identifier,
+ display_name,
+ support_contact,
+ profiles,
+ current_server.expire_time,
+ current_server.country_code.decode("utf-8"),
+ )
+ if _type is InstituteServer:
+ return InstituteServer(
+ identifier,
+ display_name,
+ support_contact,
+ profiles,
+ current_server.expire_time,
+ )
+ return Server(identifier, display_name, profiles, current_server.expire_time)
+
+
+def get_transition_server(ptr):
+ server = get_server(cast(ptr, POINTER(cServer)))
+ lib.FreeServer(ptr)
+ return server
+
+
+def get_transition_profiles(ptr):
+ profiles = get_profiles(cast(ptr, POINTER(cServerProfiles)))
+ lib.FreeProfiles(ptr)
+ return profiles
+
+
+def get_servers(ptr):
+ if ptr:
+ returned = []
+ servers = cast(ptr, POINTER(cServers)).contents
+ if servers.custom_servers:
+ for i in range(servers.total_custom):
+ current = get_server(servers.custom_servers[i], Server)
+ if current is None:
+ continue
+ returned.append(current)
+
+ if servers.institute_servers:
+ for i in range(servers.total_institute):
+ current = get_server(servers.institute_servers[i], InstituteServer)
+ if current is None:
+ continue
+ returned.append(current)
+
+ if servers.secure_internet:
+ current = get_server(servers.secure_internet, SecureInternetServer)
+ if current is not None:
+ returned.append(current)
+ lib.FreeServers(ptr)
+ return returned
+ return None
+
+
+def get_locations(ptr):
+ if ptr:
+ locations = cast(ptr, POINTER(cServerLocations)).contents
+ location_list = []
+ for i in range(locations.total_locations):
+ location_list.append(locations.locations[i].decode("utf-8"))
+ lib.FreeSecureLocations(ptr)
+ return location_list
+ return None
diff --git a/wrappers/python/eduvpn_common/state.py b/wrappers/python/eduvpn_common/state.py
new file mode 100644
index 0000000..5af004f
--- /dev/null
+++ b/wrappers/python/eduvpn_common/state.py
@@ -0,0 +1,24 @@
+from enum import IntEnum
+
+
+class StateType(IntEnum):
+ Enter = 1
+ Leave = 2
+ Wait = 3
+
+
+class State(IntEnum):
+ DEREGISTERED = 0
+ NO_SERVER = 1
+ ASK_LOCATION = 2
+ SEARCH_SERVER = 3
+ LOADING_SERVER = 4
+ CHOSEN_SERVER = 5
+ OAUTH_STARTED = 6
+ AUTHORIZED = 7
+ REQUEST_CONFIG = 8
+ ASK_PROFILE = 9
+ DISCONNECTED = 10
+ DISCONNECTING = 11
+ CONNECTING = 12
+ CONNECTED = 13