summaryrefslogtreecommitdiff
path: root/wrappers
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-14 13:56:49 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-14 13:56:49 +0200
commitda83f54606c9c1d2786d87074ee17ed972d2e1b2 (patch)
tree0be57934f9f467c87576abb0b457fb54b2d25d52 /wrappers
parentfd34e72da8c604517050ada7e883ba982829d985 (diff)
Refactor: Return without json
Diffstat (limited to 'wrappers')
-rw-r--r--wrappers/python/main.py57
-rw-r--r--wrappers/python/src/__init__.py68
-rw-r--r--wrappers/python/src/discovery.py43
-rw-r--r--wrappers/python/src/event.py21
-rw-r--r--wrappers/python/src/main.py25
-rw-r--r--wrappers/python/src/server.py158
-rw-r--r--wrappers/python/tests.py2
7 files changed, 331 insertions, 43 deletions
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index 1ab29cc..0bd2502 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -3,6 +3,8 @@ from eduvpn_common.state import State, StateType
import webbrowser
import json
import sys
+import time
+from typing import List
# Asks the user for a profile index
# It loops up until a valid input is given
@@ -27,6 +29,11 @@ def ask_profile_input(total: int) -> int:
# Sets up the callbacks using the provided class
def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None:
# The callback that starst OAuth
+ @_eduvpn.event.on(State.NO_SERVER, StateType.Enter)
+ def no_server(old_state: str, servers) -> None:
+ for server in servers:
+ print(type(server))
+ print(server)
# It needs to open the URL in the web browser
@_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter)
def oauth_initialized(old_state: str, url: str) -> None:
@@ -34,31 +41,30 @@ def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None:
webbrowser.open(url)
@_eduvpn.event.on(State.ASK_LOCATION, StateType.Enter)
- def ask_location(old_state: str, locations: str):
- print("Locations: ", locations)
- _eduvpn.set_secure_location("NL")
+ def ask_location(old_state: str, locations: List[str]):
+ _eduvpn.set_secure_location(locations[1])
- # The callback which asks the user for a profile
- @_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter)
- def ask_profile(old_state: str, profiles: str):
- print("Multiple profiles found, you need to select a profile:")
+ ## The callback which asks the user for a profile
+ #@_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter)
+ #def ask_profile(old_state: str, profiles: str):
+ # print("Multiple profiles found, you need to select a profile:")
- # Parse the profiles as JSON
- data = json.loads(profiles)
+ # # Parse the profiles as JSON
+ # data = json.loads(profiles)
- # Get a lits of profiles
- profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]]
- total_profiles = len(profile_strings)
+ # # Get a lits of profiles
+ # profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]]
+ # total_profiles = len(profile_strings)
- # Create a list of the strings to standard output
- for idx, profile in enumerate(profile_strings):
- print(f"{idx+1}. {profile}")
+ # # Create a list of the strings to standard output
+ # for idx, profile in enumerate(profile_strings):
+ # print(f"{idx+1}. {profile}")
- # Get the profile index from the user
- profile_index = ask_profile_input(total_profiles)
+ # # Get the profile index from the user
+ # profile_index = ask_profile_input(total_profiles)
- # Set the profile with the index
- _eduvpn.set_profile(profile_strings[profile_index])
+ # # Set the profile with the index
+ # _eduvpn.set_profile(profile_strings[profile_index])
# The main entry point
@@ -72,18 +78,13 @@ if __name__ == "__main__":
except Exception as e:
print("Failed registering:", e)
- server = input(
- "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
- )
-
- # Ensure we have a valid http prefix
- if not server.startswith("http"):
- # https by default
- server = "https://" + server
+ #server = input(
+ # "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
+ #)
# Get a Wireguard/OpenVPN config
try:
- config, config_type = _eduvpn.get_config_custom_server(server)
+ config, config_type = _eduvpn.get_config_secure_internet("https://idp.geant.org")
print(f"Got a config with type: {config_type} and contents:\n{config}")
except Exception as e:
print("Failed to connect:", e)
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index 5b63651..db5484f 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -40,12 +40,66 @@ class ErrorLevel(Enum):
ERR_OTHER = 0
ERR_INFO = 1
+class cServerLocations(Structure):
+ _fields_ = [
+ ("locations", POINTER(c_char_p)),
+ ("total_locations", c_size_t)
+ ]
+
+class cDiscoveryOrganization(Structure):
+ _fields_ = [
+ ("display_name", c_char_p),
+ ("org_id", c_char_p),
+ ("secure_internet_home", c_char_p),
+ ("keyword_list", c_char_p),
+ ]
+
+class cDiscoveryOrganizations(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("organizations", POINTER(POINTER(cDiscoveryOrganization))),
+ ("total_organizations", c_size_t),
+ ]
+
+class cServerProfile(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("default_gateway", c_int),
+ ]
+
+class cServerProfiles(Structure):
+ _fields_ = [
+ ("current", c_int),
+ ("profiles", POINTER(POINTER(cServerProfile))),
+ ("total_profiles", c_size_t),
+ ]
+
+class cServer(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("country_code", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ("profiles", POINTER(cServerProfiles)),
+ ("expire_time", c_ulonglong),
+ ]
+
+class cServers(Structure):
+ _fields_ = [
+ ("custom_servers", POINTER(POINTER(cServer))),
+ ("total_custom", c_size_t),
+ ("institute_servers", POINTER(POINTER(cServer))),
+ ("total_institute", c_size_t),
+ ("secure_internet", POINTER(cServer)),
+ ]
class DataError(Structure):
_fields_ = [("data", c_void_p), ("error", c_void_p)]
-VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_char_p)
+VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
@@ -77,7 +131,7 @@ lib.Register.argtypes, lib.Register.restype = [
], c_void_p
lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
c_char_p
-], DataError
+], c_void_p
lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
@@ -96,8 +150,12 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c
lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
+lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
+lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None
+lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
+lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p
class WrappedError:
@@ -139,6 +197,8 @@ def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]:
if not error_json:
return None
+ if "level" not in error_json:
+ return error_string
level = error_json["level"]
traceback = error_json["traceback"]
cause = error_json["cause"]
@@ -149,6 +209,9 @@ def get_error(ptr: c_void_p) -> str:
error = get_ptr_error(ptr)
if not error:
return ""
+
+ if not isinstance(error, WrappedError):
+ return error
return error.cause
@@ -161,7 +224,6 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]:
def get_bool(boolInt: c_int) -> bool:
return boolInt == 1
-
decode_map = {
c_int: get_bool,
c_void_p: get_error,
diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py
new file mode 100644
index 0000000..80c08cf
--- /dev/null
+++ b/wrappers/python/src/discovery.py
@@ -0,0 +1,43 @@
+from . import lib, cDiscoveryOrganizations
+from ctypes import cast, POINTER
+
+
+class DiscoOrganization:
+ def __init__(self, display_name, org_id, secure_internet_home, keyword_list):
+ self.display_name = display_name
+ self.org_id = org_id
+ self.secure_internet_home = secure_internet_home
+ self.keyword_list = keyword_list
+
+
+class DiscoOrganizations:
+ def __init__(self, version, organizations):
+ self.version = version
+ self.organizations = organizations
+
+
+def get_disco_organization(ptr):
+ if not ptr:
+ return None
+
+ current_organization = ptr.contents
+ display_name = current_organization.display_name.decode("utf-8")
+ org_id = current_organization.org_id.decode("utf-8")
+ secure_internet_home = current_organization.secure_internet_home.decode("utf-8")
+ keyword_list = current_organization.keyword_list.decode("utf-8")
+ return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
+
+
+def get_disco_organizations(ptr):
+ if ptr:
+ orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
+ organizations = []
+ if orgs.organizations:
+ for i in range(orgs.total_organizations):
+ current = get_disco_organization(orgs.organizations[i])
+ if current is None:
+ continue
+ organizations.append(current)
+ lib.FreeDiscoOrganizations(ptr)
+ return DiscoOrganizations(orgs.version, organizations)
+ return None
diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py
index d0740f8..0e0f5ae 100644
--- a/wrappers/python/src/event.py
+++ b/wrappers/python/src/event.py
@@ -1,7 +1,8 @@
-from . import VPNStateChange
+from . import VPNStateChange, get_ptr_string
from enum import Enum
from typing import Callable
-from .state import StateType
+from .state import State, StateType
+from .server import get_locations, get_servers
EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
@@ -15,6 +16,15 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
+def convert_data(state: State, data):
+ if not data:
+ return None
+ if state is State.NO_SERVER:
+ return get_servers(data)
+ if state is State.OAUTH_STARTED:
+ return get_ptr_string(data)
+ if state is State.ASK_LOCATION:
+ return get_locations(data)
class EventHandler(object):
def __init__(self):
@@ -73,6 +83,7 @@ class EventHandler(object):
def run(self, old_state: int, new_state: int, data: str) -> None:
# First run leave transitions, then enter
# The state is done when the wait event finishes
- self.run_state(old_state, new_state, StateType.Leave, data)
- self.run_state(new_state, old_state, StateType.Enter, data)
- self.run_state(new_state, old_state, StateType.Wait, data)
+ converted = convert_data(new_state, data)
+ self.run_state(old_state, new_state, StateType.Leave, converted)
+ self.run_state(new_state, old_state, StateType.Enter, converted)
+ self.run_state(new_state, old_state, StateType.Wait, converted)
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index b37842f..cbeadb5 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,8 +1,10 @@
from . import lib, VPNStateChange, encode_args, decode_res
from typing import Optional, Tuple
import threading
+from .discovery import get_disco_organizations
from .event import EventHandler
from .state import State, StateType
+from .server import get_servers
import json
eduvpn_objects = {}
@@ -26,7 +28,7 @@ def state_callback(name, old_state, new_state, data):
name = name.decode()
if name not in eduvpn_objects:
return
- eduvpn_objects[name].callback(State(old_state), State(new_state), data.decode())
+ eduvpn_objects[name].callback(State(old_state), State(new_state), data)
class EduVPN(object):
@@ -58,6 +60,12 @@ class EduVPN(object):
res = func(self.name.encode("utf-8"), *(args_gen))
return decode_res(func.restype)(res)
+ def go_function_custom_decode(self, func, decode_func, *args):
+ # The functions all have at least one arg type which is the name of the client
+ args_gen = encode_args(list(args), func.argtypes[1:])
+ res = func(self.name.encode("utf-8"), *(args_gen))
+ return decode_func(res)
+
def cancel_oauth(self) -> None:
cancel_oauth_err = self.go_function(lib.CancelOAuth)
@@ -91,10 +99,9 @@ class EduVPN(object):
return servers
def get_disco_organizations(self) -> str:
- organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations)
-
- if organizations_err:
- raise Exception(organizations_err)
+ organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations)
+ #if organizations_err:
+ # raise Exception(organizations_err)
return organizations
@@ -196,7 +203,7 @@ class EduVPN(object):
def event(self) -> EventHandler:
return self.event_handler
- def callback(self, old_state: State, new_state: State, data: str) -> None:
+ def callback(self, old_state: State, new_state: State, data) -> None:
self.event.run(old_state, new_state, data)
def set_profile(self, profile_id: str) -> None:
@@ -242,3 +249,9 @@ class EduVPN(object):
def in_fsm_state(self, state_id: State) -> bool:
return self.go_function(lib.InFSMState, state_id)
+
+ def get_saved_servers_old(self) -> str:
+ return self.go_function(lib.GetSavedServersOLD)
+
+ def get_saved_servers_new(self) -> str:
+ return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers)
diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py
new file mode 100644
index 0000000..b765ede
--- /dev/null
+++ b/wrappers/python/src/server.py
@@ -0,0 +1,158 @@
+from . import lib, cServers, cServerLocations
+from ctypes import cast, POINTER
+
+
+class Profile:
+ def __init__(self, identifier, display_name, default_gateway: bool):
+ self.identifier = identifier
+ self.display_name = display_name
+ self.default_gateway = default_gateway
+
+ def __str__(self):
+ return f"Profile: {self.display_name}"
+
+
+class Server:
+ def __init__(self, url, display_name, profiles, current_profile, expire_time):
+ self.url = url
+ self.display_name = display_name
+ self.profiles = profiles
+ self.current_profile = None
+ if current_profile < len(profiles):
+ self.current_profile = profiles[current_profile]
+ self.expire_time = expire_time
+
+ def __str__(self):
+ return f"Server: {self.url}, with current profile: {self.current_profile}"
+
+
+class InstituteServer(Server):
+ def __init__(
+ self, url, display_name, support_contact, profiles, current_profile, expire_time
+ ):
+ super().__init__(url, display_name, profiles, current_profile, expire_time)
+ self.support_contact = support_contact
+
+ def __str__(self):
+ return f"Institute Server: {self.display_name}"
+
+
+class SecureInternetServer(Server):
+ def __init__(
+ self,
+ url,
+ display_name,
+ support_contact,
+ profiles,
+ current_profile,
+ expire_time,
+ country_code,
+ ):
+ super().__init__(url, display_name, profiles, current_profile, expire_time)
+ self.support_contact = support_contact
+ self.country_code = country_code
+
+ def __str__(self):
+ return f"Secure Internet Server: {self.display_name} with country {self.country_code}"
+
+
+def get_type_for_str(type_str: str):
+ if type_str is "secure_internet":
+ return SecureInternetServer
+ if type_str is "custom_server":
+ return Server
+ return InstituteServer
+
+
+def get_server(ptr, _type=None):
+ if not ptr:
+ return None
+
+ current_server = ptr.contents
+ if _type is None:
+ _type = get_type_for_str(current_server.server_type.decode("utf-8"))
+
+ identifier = current_server.identifier.decode("utf-8")
+ display_name = current_server.display_name.decode("utf-8")
+
+ if _type is not Server:
+ support_contact = []
+ for i in range(current_server.total_support_contact):
+ support_contact.append(current_server.support_contact[i].decode("utf-8"))
+ profiles = []
+ if not current_server.profiles:
+ return None
+
+ _profiles = current_server.profiles.contents
+ current_profile = _profiles.current
+ for i in range(_profiles.total_profiles):
+ if not _profiles.profiles or not _profiles.profiles[i]:
+ return None
+ profile = _profiles.profiles[i].contents
+ profiles.append(
+ Profile(
+ profile.identifier.decode("utf-8"),
+ profile.display_name.decode("utf-8"),
+ profile.default_gateway == 1,
+ )
+ )
+
+ if _type is SecureInternetServer:
+ return SecureInternetServer(
+ identifier,
+ display_name,
+ support_contact,
+ profiles,
+ current_profile,
+ current_server.expire_time,
+ current_server.country_code.decode("utf-8"),
+ )
+ if _type is InstituteServer:
+ return InstituteServer(
+ identifier,
+ display_name,
+ support_contact,
+ profiles,
+ current_profile,
+ current_server.expire_time,
+ )
+ return Server(
+ identifier, display_name, profiles, current_profile, current_server.expire_time
+ )
+
+
+def get_servers(ptr):
+ if ptr:
+ returned = []
+ servers = cast(ptr, POINTER(cServers)).contents
+ if servers.custom_servers:
+ for i in range(servers.total_custom):
+ current = get_server(servers.custom_servers[i], Server)
+ if current is None:
+ continue
+ returned.append(current)
+
+ if servers.institute_servers:
+ for i in range(servers.total_institute):
+ current = get_server(servers.institute_servers[i], InstituteServer)
+ if current is None:
+ continue
+ returned.append(current)
+
+ if servers.secure_internet:
+ current = get_server(servers.secure_internet, SecureInternetServer)
+ if current is not None:
+ returned.append(current)
+ lib.FreeServers(ptr)
+ return returned
+ return None
+
+def get_locations(ptr):
+ if ptr:
+ locations = cast(ptr, POINTER(cServerLocations)).contents
+ location_list = []
+ for i in range(locations.total_locations):
+ location_list.append(locations.locations[i].decode("utf-8"))
+ lib.FreeSecureLocations(ptr)
+ return location_list
+ return None
diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py
index 60d3cce..679eda0 100644
--- a/wrappers/python/tests.py
+++ b/wrappers/python/tests.py
@@ -24,7 +24,7 @@ class ConfigTests(unittest.TestCase):
@_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter)
def oauth_initialized(old_state, url_json):
- login_eduvpn(json.loads(url_json))
+ login_eduvpn(url_json)
server_uri = os.getenv("SERVER_URI")
if not server_uri: