summaryrefslogtreecommitdiff
path: root/wrappers
diff options
context:
space:
mode:
Diffstat (limited to 'wrappers')
-rw-r--r--wrappers/python/main.py92
-rw-r--r--wrappers/python/src/__init__.py36
-rw-r--r--wrappers/python/src/main.py189
-rw-r--r--wrappers/python/tests.py29
4 files changed, 258 insertions, 88 deletions
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index 1c1afd7..c887fed 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -1,32 +1,90 @@
import eduvpncommon.main as eduvpn
import webbrowser
+import json
+# Asks the user for a profile index
+# It loops up until a valid input is given
+def ask_profile_input(total: int) -> int:
+ profile_index = None
-_eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "configs")
+ while profile_index is None:
+ try:
+ profile_index = int(
+ input("Please select a profile by inputting a number (e.g. 1): ")
+ )
+ if (profile_index > total) or (profile_index < 1):
+ print("Invalid profile range")
+ profile_index = None
+ except ValueError:
+ print("Please enter a valid input")
+ # The profile is one based, move to zero based input
+ return profile_index - 1
-@_eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter)
-def oauth_initialized(url):
- print(f"Got OAUTH url {url}")
- webbrowser.open(url)
+# Sets up the callbacks using the provided class
+def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None:
+ # The callback that starst OAuth
+ # It needs to open the URL in the web browser
+ @_eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter)
+ def oauth_initialized(old_state: str, url: str) -> None:
+ print(f"Got OAuth URL {url}, old state: {old_state}")
+ webbrowser.open(url)
+ # The callback which asks the user for a profile
+ @_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter)
+ def ask_profile(old_state: str, profiles: str):
+ print("Multiple profiles found, you need to select a profile, old state: {old_state}")
-@_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter)
-def ask_profile(profiles):
- print("ASK PROFILE CB", profiles)
- _eduvpn.set_profile("prefer-openvpn")
+ # Parse the profiles as JSON
+ data = json.loads(profiles)
+ # Get a lits of profiles
+ profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]]
+ total_profiles = len(profile_strings)
-success = _eduvpn.register(debug=True)
+ # Create a list of the strings to standard output
+ for idx, profile in enumerate(profile_strings):
+ print(f"{idx+1}. {profile}")
-if not success:
- print("failed to register")
+ # Get the profile index from the user
+ profile_index = ask_profile_input(total_profiles)
-print(_eduvpn.get_disco())
+ # Set the profile with the index
+ _eduvpn.set_profile(profile_strings[profile_index])
-config, error = _eduvpn.get_config_institute_access("https://eduvpn.jwijenbergh.com")
-if error:
- print("Got connect error", error)
+# The main entry point
+if __name__ == "__main__":
+ _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "configs")
+ setup_callbacks(_eduvpn)
-print(config)
+ # Register with the eduVPN-common library
+ try:
+ _eduvpn.register(debug=True)
+ except Exception as e:
+ print("Failed registering:", e)
+
+ server = input(
+ "Which Institute Access server do you want to connect to? (e.g. https://eduvpn.example.com): "
+ )
+
+ # Ensure we have a valid http prefix
+ if not server.startswith("http"):
+ # https by default
+ server = "https://" + server
+
+ # Get a Wireguard/OpenVPN config
+ try:
+ config, config_type = _eduvpn.get_config_institute_access(server)
+ except Exception as e:
+ print("Failed to connect:", e)
+ print(f"Got a config with type: {config_type} and contents:\n{config}")
+
+ # Set the internal FSM state to connected
+ try:
+ _eduvpn.set_connected()
+ except Exception as e:
+ print("Failed to set connected:", e)
+
+ # Save and exit
+ _eduvpn.deregister()
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index c028f09..c96a1b2 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -2,6 +2,7 @@ from ctypes import *
from collections import defaultdict
import pathlib
import platform
+from typing import Tuple
_lib_prefixes = defaultdict(
lambda: "lib",
@@ -30,15 +31,31 @@ class DataError(Structure):
_fields_ = [("data", c_void_p), ("error", c_void_p)]
+class MultipleDataError(Structure):
+ _fields_ = [("data", c_void_p), ("other_data", c_void_p), ("error", c_void_p)]
+
+
VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
-lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [c_char_p, c_char_p, c_int, c_int], DataError
+lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+ c_int,
+], MultipleDataError
lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], c_void_p
-lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p
-lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [c_char_p], DataError
+lib.Register.argtypes, lib.Register.restype = [
+ c_char_p,
+ c_char_p,
+ VPNStateChange,
+ c_int,
+], c_void_p
+lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [
+ c_char_p
+], DataError
lib.GetServersList.argtypes, lib.GetServersList.restype = [c_char_p], DataError
lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
@@ -47,7 +64,7 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p], c_void_p
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
-def GetPtrString(ptr):
+def GetPtrString(ptr: c_void_p) -> str:
if ptr:
string = cast(ptr, c_char_p).value
lib.FreeString(ptr)
@@ -56,7 +73,16 @@ def GetPtrString(ptr):
return ""
-def GetDataError(data_error):
+def GetDataError(data_error: DataError) -> Tuple[str, str]:
data = GetPtrString(data_error.data)
error = GetPtrString(data_error.error)
return data, error
+
+
+def GetMultipleDataError(
+ multiple_data_error: MultipleDataError,
+) -> Tuple[str, str, str]:
+ data = GetPtrString(multiple_data_error.data)
+ other_data = GetPtrString(multiple_data_error.other_data)
+ error = GetPtrString(multiple_data_error.error)
+ return data, other_data, error
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 2b346e3..dd8f36a 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,6 +1,7 @@
-from . import lib, VPNStateChange, GetDataError, GetPtrString
+from . import lib, VPNStateChange, GetDataError, GetMultipleDataError, GetPtrString
from ctypes import *
from enum import Enum
+from typing import Callable, Optional, Tuple
class StateType(Enum):
@@ -8,47 +9,94 @@ class StateType(Enum):
Leave = 2
+class EventHandler(object):
+ def __init__(self):
+ self.handlers = {}
+
+ def on(self, state: str, state_type: StateType) -> Callable:
+ def wrapped_f(func):
+ if (state, state_type) not in self.handlers:
+ self.handlers[(state, state_type)] = []
+ self.handlers[(state, state_type)].append(func)
+ return func
+
+ return wrapped_f
+
+ def run_state(self, state: str, other_state: str, state_type: StateType, data: str) -> None:
+ if (state, state_type) not in self.handlers:
+ return
+ for func in self.handlers[(state, state_type)]:
+ func(other_state, data)
+
+ def run(self, old_state: str, new_state: str, data: str) -> None:
+ if old_state == new_state:
+ return
+ self.run_state(old_state, new_state, StateType.Leave, data)
+ self.run_state(new_state, old_state, StateType.Enter, data)
+
+
# Registers the python app with the Go code
# name: The name of the app to be registered
# state_callback: The callback to trigger whenever a state is changed
-def Register(name, config_directory, state_callback, debug):
+def Register(
+ name: str, config_directory: str, state_callback: Optional[Callable], debug: bool
+) -> str:
+ if not state_callback:
+ return "No callback provided"
name_bytes = name.encode("utf-8")
dir_bytes = config_directory.encode("utf-8")
ptr_err = lib.Register(name_bytes, dir_bytes, state_callback, debug)
err_string = GetPtrString(ptr_err)
return err_string
-def CancelOAuth(name):
+
+def CancelOAuth(name: str) -> str:
name_bytes = name.encode("utf-8")
ptr_err = lib.CancelOAuth(name_bytes)
err_string = GetPtrString(ptr_err)
return err_string
-def Deregister(name):
+
+def Deregister(name: str) -> str:
name_bytes = name.encode("utf-8")
ptr_err = lib.Deregister(name_bytes)
err_string = GetPtrString(ptr_err)
return err_string
-def GetDiscoServers(name):
+
+def GetDiscoServers(name: str) -> Tuple[str, str]:
+ name_bytes = name.encode("utf-8")
+ servers, servers_err = GetDataError(lib.GetServersList(name_bytes))
+ return servers, servers_err
+
+
+def GetDiscoOrganizations(name: str) -> Tuple[str, str]:
name_bytes = name.encode("utf-8")
- servers, serversErr = GetDataError(lib.GetServersList(name_bytes))
- organizations, organizationsErr = GetDataError(lib.GetOrganizationsList(name_bytes))
- return servers, serversErr, organizations, organizationsErr
+ organizations, organizations_err = GetDataError(
+ lib.GetOrganizationsList(name_bytes)
+ )
+ return organizations, organizations_err
+
-def GetConnectConfig(name, url, is_secure_internet, force_tcp):
+def GetConnectConfig(
+ name: str, url: str, is_secure_internet: bool, force_tcp: bool
+) -> Tuple[str, str, str]:
name_bytes = name.encode("utf-8")
url_bytes = url.encode("utf-8")
- data_error = lib.GetConnectConfig(name_bytes, url_bytes, is_secure_internet, force_tcp)
- return GetDataError(data_error)
+ multiple_data_error = lib.GetConnectConfig(
+ name_bytes, url_bytes, is_secure_internet, force_tcp
+ )
+ return GetMultipleDataError(multiple_data_error)
+
-def SetConnected(name):
+def SetConnected(name: str) -> str:
name_bytes = name.encode("utf-8")
ptr_err = lib.SetConnected(name_bytes)
err_string = GetPtrString(ptr_err)
return err_string
-def SetDisconnected(name):
+
+def SetDisconnected(name: str) -> str:
name_bytes = name.encode("utf-8")
ptr_err = lib.SetDisconnected(name_bytes)
err_string = GetPtrString(ptr_err)
@@ -68,7 +116,7 @@ def register_callback(eduvpn):
)
-def SetProfileID(name, profile_id) -> str:
+def SetProfileID(name: str, profile_id: str) -> str:
name_bytes = name.encode("utf-8")
profile_bytes = profile_id.encode("utf-8")
error_string = lib.SetProfileID(name_bytes, profile_bytes)
@@ -76,68 +124,93 @@ def SetProfileID(name, profile_id) -> str:
class EduVPN(object):
- def __init__(self, name, config_directory):
+ def __init__(self, name: str, config_directory: str):
self.event_handler = EventHandler()
self.name = name
self.config_directory = config_directory
register_callback(self)
- def cancel_oauth(self) -> str:
- return CancelOAuth(self.name)
+ def cancel_oauth(self) -> None:
+ cancel_oauth_err = CancelOAuth(self.name)
- def deregister(self) -> str:
- return Deregister(self.name)
+ if cancel_oauth_err:
+ raise Exception(cancel_oauth_err)
- def register(self, debug=False) -> bool:
- return Register(self.name, self.config_directory, callback_function, debug) == ""
+ def deregister(self) -> None:
+ deregister_err = Deregister(self.name)
- def get_disco(self):
- return GetDiscoServers(self.name)
+ if deregister_err:
+ raise Exception(deregister_err)
- def get_config_institute_access(self, url, force_tcp=False):
- return GetConnectConfig(self.name, url, False, force_tcp)
+ def register(self, debug: bool = False) -> None:
+ register_err = Register(
+ self.name, self.config_directory, callback_function, debug
+ )
- def get_config_secure_internet(self, url, force_tcp=False):
- return GetConnectConfig(self.name, url, True, force_tcp)
+ if register_err:
+ raise Exception(register_err)
- def set_disconnected(self):
- return SetDisconnected(self.name)
+ def get_disco_servers(self) -> str:
+ servers, servers_err = GetDiscoServers(self.name)
- def set_connected(self):
- return SetConnected(self.name)
+ if servers_err:
+ raise Exception(servers_err)
- @property
- def event(self):
- return self.event_handler
+ return servers
- def callback(self, old_state, new_state, data):
- self.event.run(old_state, new_state, data)
+ def get_disco_organizations(self) -> str:
+ organizations, organizations_err = GetDiscoOrganizations(self.name)
- def set_profile(self, profile_id) -> str:
- return SetProfileID(self.name, profile_id)
+ if organizations_err:
+ raise Exception(organizations_err)
+ return organizations
-class EventHandler(object):
- def __init__(self):
- self.handlers = {}
+ def get_config_institute_access(
+ self, url: str, force_tcp: bool = False
+ ) -> Tuple[str, str]:
+ config, config_type, config_err = GetConnectConfig(
+ self.name, url, False, force_tcp
+ )
- def on(self, state, state_type):
- def wrapped_f(func):
- if (state, state_type) not in self.handlers:
- self.handlers[(state, state_type)] = []
- self.handlers[(state, state_type)].append(func)
- return func
+ if config_err:
+ raise Exception(config_err)
- return wrapped_f
+ return config, config_type
- def run_state(self, state, state_type, data):
- if (state, state_type) not in self.handlers:
- return
- for func in self.handlers[(state, state_type)]:
- func(data)
+ def get_config_secure_internet(
+ self, url: str, force_tcp: bool = False
+ ) -> Tuple[str, str]:
+ config, config_type, config_err = GetConnectConfig(
+ self.name, url, True, force_tcp
+ )
- def run(self, old_state, new_state, data):
- if old_state == new_state:
- return
- self.run_state(old_state, StateType.Leave, data)
- self.run_state(new_state, StateType.Enter, data)
+ if config_err:
+ raise Exception(config_err)
+
+ return config, config_type
+
+ def set_disconnected(self) -> None:
+ disconnect_err = SetDisconnected(self.name)
+
+ if disconnect_err:
+ raise Exception(disconnect_err)
+
+ def set_connected(self) -> None:
+ connect_err = SetConnected(self.name)
+
+ if connect_err:
+ raise Exception(connect_err)
+
+ @property
+ def event(self) -> EventHandler:
+ return self.event_handler
+
+ def callback(self, old_state: str, new_state: str, data: str) -> None:
+ self.event.run(old_state, new_state, data)
+
+ def set_profile(self, profile_id: str) -> None:
+ profile_err = SetProfileID(self.name, profile_id)
+
+ if profile_err:
+ raise Exception(profile_err)
diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py
index f006646..60ed79e 100644
--- a/wrappers/python/tests.py
+++ b/wrappers/python/tests.py
@@ -13,20 +13,33 @@ from selenium_eduvpn import login_eduvpn
class ConfigTests(unittest.TestCase):
def testConfig(self):
- self._eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "testconfigs")
- assert self._eduvpn.register()
- @self._eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter)
- def oauth_initialized(url):
+ _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "testconfigs")
+ # This can throw an exception
+ _eduvpn.register()
+ @_eduvpn.event.on("OAuth_Started", eduvpn.StateType.Enter)
+ def oauth_initialized(old_state, url):
login_eduvpn(url)
server_uri = os.getenv("SERVER_URI")
if not server_uri:
self.fail("No SERVER_URI environment variable given")
- config, error = self._eduvpn.get_config_institute_access(server_uri)
-
- if error != "":
- self.fail(f"Got error: {error} when connecting to {server_uri}")
+ # This can throw an exception
+ _eduvpn.get_config_institute_access(server_uri)
+
+ # Deregister
+ _eduvpn.deregister()
+
+ def testDoubleRegister(self):
+ _eduvpn = eduvpn.EduVPN("org.eduvpn.app.linux", "testconfigs")
+ # This can throw an exception
+ _eduvpn.register()
+ # This should throw
+ try:
+ _eduvpn.register()
+ except Exception as e:
+ return
+ self.fail("No exception thrown on second register")
if __name__ == "__main__":
unittest.main()