summaryrefslogtreecommitdiff
path: root/wrappers/python
diff options
context:
space:
mode:
Diffstat (limited to 'wrappers/python')
-rw-r--r--wrappers/python/src/__init__.py73
-rw-r--r--wrappers/python/src/discovery.py85
-rw-r--r--wrappers/python/src/event.py26
-rw-r--r--wrappers/python/src/main.py20
-rw-r--r--wrappers/python/src/server.py109
5 files changed, 247 insertions, 66 deletions
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index ece4b46..ca07143 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -5,6 +5,7 @@ import pathlib
import platform
from typing import Tuple, Optional
import json
+from typing import List
_lib_prefixes = defaultdict(
lambda: "lib",
@@ -40,11 +41,10 @@ class ErrorLevel(Enum):
ERR_OTHER = 0
ERR_INFO = 1
+
class cServerLocations(Structure):
- _fields_ = [
- ("locations", POINTER(c_char_p)),
- ("total_locations", c_size_t)
- ]
+ _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)]
+
class cDiscoveryOrganization(Structure):
_fields_ = [
@@ -54,6 +54,7 @@ class cDiscoveryOrganization(Structure):
("keyword_list", c_char_p),
]
+
class cDiscoveryOrganizations(Structure):
_fields_ = [
("version", c_ulonglong),
@@ -61,6 +62,30 @@ class cDiscoveryOrganizations(Structure):
("total_organizations", c_size_t),
]
+
+class cDiscoveryServer(Structure):
+ _fields_ = [
+ ("authentication_url_template", c_char_p),
+ ("base_url", c_char_p),
+ ("country_code", c_char_p),
+ ("display_name", c_char_p),
+ ("keyword_list", c_char_p),
+ ("public_key_list", POINTER(c_char_p)),
+ ("total_public_keys", c_size_t),
+ ("server_type", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ]
+
+
+class cDiscoveryServers(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("servers", POINTER(POINTER(cDiscoveryServer))),
+ ("total_servers", c_size_t),
+ ]
+
+
class cServerProfile(Structure):
_fields_ = [
("identifier", c_char_p),
@@ -68,6 +93,7 @@ class cServerProfile(Structure):
("default_gateway", c_int),
]
+
class cServerProfiles(Structure):
_fields_ = [
("current", c_int),
@@ -75,10 +101,12 @@ class cServerProfiles(Structure):
("total_profiles", c_size_t),
]
+
class cServer(Structure):
_fields_ = [
("identifier", c_char_p),
("display_name", c_char_p),
+ ("server_type", c_char_p),
("country_code", c_char_p),
("support_contact", POINTER(c_char_p)),
("total_support_contact", c_size_t),
@@ -86,6 +114,7 @@ class cServer(Structure):
("expire_time", c_ulonglong),
]
+
class cServers(Structure):
_fields_ = [
("custom_servers", POINTER(POINTER(cServer))),
@@ -95,6 +124,7 @@ class cServers(Structure):
("secure_internet", POINTER(cServer)),
]
+
class DataError(Structure):
_fields_ = [("data", c_void_p), ("error", c_void_p)]
@@ -104,9 +134,17 @@ VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
-lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [c_char_p], c_void_p
-lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [c_char_p, c_char_p], c_void_p
-lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [c_char_p, c_char_p], c_void_p
+lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [
+ c_char_p
+], c_void_p
+lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
+lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
c_char_p,
c_char_p,
@@ -132,7 +170,7 @@ lib.Register.argtypes, lib.Register.restype = [
lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
c_char_p
], c_void_p
-lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
+lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], c_void_p
lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
@@ -150,9 +188,14 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c
lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
+lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None
lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
-lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None
+lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
+ c_void_p
+], None
+lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None
+lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None
lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p
@@ -186,6 +229,17 @@ def get_ptr_string(ptr: c_void_p) -> str:
return ""
+def get_ptr_list_strings(
+ strings: POINTER(c_char_p), total_strings: c_size_t
+) -> List[str]:
+ if strings:
+ strings_list = []
+ for i in range(total_strings):
+ strings_list.append(strings[i].decode("utf-8"))
+ return strings_list
+ return []
+
+
def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]:
error_string = get_ptr_string(ptr)
@@ -224,6 +278,7 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]:
def get_bool(boolInt: c_int) -> bool:
return boolInt == 1
+
decode_map = {
c_int: get_bool,
c_void_p: get_error,
diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py
index 80c08cf..a1b87ea 100644
--- a/wrappers/python/src/discovery.py
+++ b/wrappers/python/src/discovery.py
@@ -1,4 +1,4 @@
-from . import lib, cDiscoveryOrganizations
+from . import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings
from ctypes import cast, POINTER
@@ -9,6 +9,9 @@ class DiscoOrganization:
self.secure_internet_home = secure_internet_home
self.keyword_list = keyword_list
+ def __str__(self):
+ return self.display_name
+
class DiscoOrganizations:
def __init__(self, version, organizations):
@@ -16,6 +19,37 @@ class DiscoOrganizations:
self.organizations = organizations
+class DiscoServer:
+ def __init__(
+ self,
+ authentication_url_template,
+ base_url,
+ country_code,
+ display_name,
+ keyword_list,
+ public_keys,
+ server_type,
+ support_contacts,
+ ):
+ self.authentication_url_template = authentication_url_template
+ self.base_url = base_url
+ self.country_code = country_code
+ self.display_name = display_name
+ self.keyword_list = keyword_list
+ self.public_keys = public_keys
+ self.server_type = server_type
+ self.support_contacts = support_contacts
+
+ def __str__(self):
+ return self.display_name
+
+
+class DiscoServers:
+ def __init__(self, version, servers):
+ self.version = version
+ self.servers = servers
+
+
def get_disco_organization(ptr):
if not ptr:
return None
@@ -28,6 +62,55 @@ def get_disco_organization(ptr):
return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
+def get_disco_server(ptr):
+ if not ptr:
+ return None
+
+ current_server = ptr.contents
+ authentication_url_template = current_server.authentication_url_template.decode(
+ "utf-8"
+ )
+ base_url = current_server.base_url.decode("utf-8")
+ country_code = current_server.country_code.decode("utf-8")
+ display_name = current_server.display_name.decode("utf-8")
+ keyword_list = current_server.keyword_list.decode("utf-8")
+ public_keys = get_ptr_list_strings(
+ current_server.public_key_list, current_server.total_public_keys
+ )
+ server_type = current_server.server_type.decode("utf-8")
+ support_contacts = get_ptr_list_strings(
+ current_server.support_contact, current_server.total_support_contact
+ )
+ return DiscoServer(
+ authentication_url_template,
+ base_url,
+ country_code,
+ display_name,
+ keyword_list,
+ public_keys,
+ server_type,
+ support_contacts,
+ )
+
+
+def get_disco_servers(ptr):
+ if ptr:
+ svrs = cast(ptr, POINTER(cDiscoveryServers)).contents
+
+ servers = []
+
+ if svrs.servers:
+ for i in range(svrs.total_servers):
+ current = get_disco_server(svrs.servers[i])
+
+ if current is None:
+ continue
+ servers.append(current)
+ lib.FreeDiscoServers(ptr)
+ return DiscoServers(svrs.version, servers)
+ return None
+
+
def get_disco_organizations(ptr):
if ptr:
orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py
index 0e0f5ae..cf1a9d0 100644
--- a/wrappers/python/src/event.py
+++ b/wrappers/python/src/event.py
@@ -2,7 +2,12 @@ from . import VPNStateChange, get_ptr_string
from enum import Enum
from typing import Callable
from .state import State, StateType
-from .server import get_locations, get_servers
+from .server import (
+ get_locations,
+ get_transition_profiles,
+ get_transition_server,
+ get_servers,
+)
EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
@@ -16,6 +21,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
+
def convert_data(state: State, data):
if not data:
return None
@@ -25,6 +31,16 @@ def convert_data(state: State, data):
return get_ptr_string(data)
if state is State.ASK_LOCATION:
return get_locations(data)
+ if state is State.ASK_PROFILE:
+ return get_transition_profiles(data)
+ if state in [
+ State.DISCONNECTED,
+ State.DISCONNECTING,
+ State.CONNECTING,
+ State.CONNECTED,
+ ]:
+ return get_transition_server(data)
+
class EventHandler(object):
def __init__(self):
@@ -80,10 +96,14 @@ class EventHandler(object):
for func in self.handlers[(state, state_type)]:
func(other_state, data)
- def run(self, old_state: int, new_state: int, data: str) -> None:
+ def run(
+ self, old_state: int, new_state: int, data: str, convert: bool = True
+ ) -> None:
# First run leave transitions, then enter
# The state is done when the wait event finishes
- converted = convert_data(new_state, data)
+ converted = data
+ if convert:
+ converted = convert_data(new_state, data)
self.run_state(old_state, new_state, StateType.Leave, converted)
self.run_state(new_state, old_state, StateType.Enter, converted)
self.run_state(new_state, old_state, StateType.Wait, converted)
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 8425992..e20b84f 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,7 +1,7 @@
from . import lib, VPNStateChange, encode_args, decode_res
from typing import Optional, Tuple
import threading
-from .discovery import get_disco_organizations
+from .discovery import get_disco_organizations, get_disco_servers
from .event import EventHandler
from .state import State, StateType
from .server import get_servers
@@ -88,16 +88,20 @@ class EduVPN(object):
raise Exception(register_err)
def get_disco_servers(self) -> str:
- servers, servers_err = self.go_function(lib.GetDiscoServers)
+ servers = self.go_function_custom_decode(
+ lib.GetDiscoServers, decode_func=get_disco_servers
+ )
- if servers_err:
- raise Exception(servers_err)
+ # if servers_err:
+ # raise Exception(servers_err)
return servers
def get_disco_organizations(self) -> str:
- organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations)
- #if organizations_err:
+ organizations = self.go_function_custom_decode(
+ lib.GetDiscoOrganizations, decode_func=get_disco_organizations
+ )
+ # if organizations_err:
# raise Exception(organizations_err)
return organizations
@@ -251,4 +255,6 @@ class EduVPN(object):
return self.go_function(lib.GetSavedServersOLD)
def get_saved_servers_new(self) -> str:
- return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers)
+ return self.go_function_custom_decode(
+ lib.GetSavedServersNEW, decode_func=get_servers
+ )
diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py
index b765ede..dce5d51 100644
--- a/wrappers/python/src/server.py
+++ b/wrappers/python/src/server.py
@@ -1,5 +1,6 @@
-from . import lib, cServers, cServerLocations
-from ctypes import cast, POINTER
+from . import lib, cServer, cServers, cServerLocations, cServerProfiles
+from ctypes import cast, POINTER, c_char_p
+from datetime import datetime
class Profile:
@@ -9,61 +10,85 @@ class Profile:
self.default_gateway = default_gateway
def __str__(self):
- return f"Profile: {self.display_name}"
+ return self.display_name
+
+
+class Profiles:
+ def __init__(self, profiles, current):
+ self.profiles = profiles
+ self.current_index = current
+
+ @property
+ def current(self):
+ if self.current_index < len(self.profiles):
+ return self.profiles[self.current_index]
+ return None
class Server:
- def __init__(self, url, display_name, profiles, current_profile, expire_time):
+ def __init__(self, url, display_name, profiles=None, expire_time=0):
self.url = url
self.display_name = display_name
self.profiles = profiles
self.current_profile = None
- if current_profile < len(profiles):
- self.current_profile = profiles[current_profile]
- self.expire_time = expire_time
+ self.expire_time = datetime.fromtimestamp(expire_time)
def __str__(self):
- return f"Server: {self.url}, with current profile: {self.current_profile}"
+ return self.display_name
class InstituteServer(Server):
- def __init__(
- self, url, display_name, support_contact, profiles, current_profile, expire_time
- ):
- super().__init__(url, display_name, profiles, current_profile, expire_time)
+ def __init__(self, url, display_name, support_contact, profiles, expire_time):
+ super().__init__(url, display_name, profiles, expire_time)
self.support_contact = support_contact
- def __str__(self):
- return f"Institute Server: {self.display_name}"
-
class SecureInternetServer(Server):
def __init__(
self,
- url,
+ org_id,
display_name,
support_contact,
profiles,
- current_profile,
expire_time,
country_code,
):
- super().__init__(url, display_name, profiles, current_profile, expire_time)
+ super().__init__(org_id, display_name, profiles, expire_time)
+ self.org_id = org_id
self.support_contact = support_contact
self.country_code = country_code
- def __str__(self):
- return f"Secure Internet Server: {self.display_name} with country {self.country_code}"
-
def get_type_for_str(type_str: str):
- if type_str is "secure_internet":
+ if type_str == "secure_internet":
return SecureInternetServer
- if type_str is "custom_server":
+ if type_str == "custom_server":
return Server
return InstituteServer
+def get_profiles(ptr):
+ if not ptr:
+ return []
+ profiles = []
+ _profiles = ptr.contents
+ current_profile = _profiles.current
+ if not _profiles.profiles:
+ return []
+ for i in range(_profiles.total_profiles):
+ if not _profiles.profiles[i]:
+ continue
+ profile = _profiles.profiles[i].contents
+ profiles.append(
+ Profile(
+ profile.identifier.decode("utf-8"),
+ profile.display_name.decode("utf-8"),
+ profile.default_gateway == 1,
+ )
+ )
+ return Profiles(profiles, current_profile)
+
+
def get_server(ptr, _type=None):
if not ptr:
return None
@@ -79,31 +104,13 @@ def get_server(ptr, _type=None):
support_contact = []
for i in range(current_server.total_support_contact):
support_contact.append(current_server.support_contact[i].decode("utf-8"))
- profiles = []
- if not current_server.profiles:
- return None
-
- _profiles = current_server.profiles.contents
- current_profile = _profiles.current
- for i in range(_profiles.total_profiles):
- if not _profiles.profiles or not _profiles.profiles[i]:
- return None
- profile = _profiles.profiles[i].contents
- profiles.append(
- Profile(
- profile.identifier.decode("utf-8"),
- profile.display_name.decode("utf-8"),
- profile.default_gateway == 1,
- )
- )
-
+ profiles = get_profiles(current_server.profiles)
if _type is SecureInternetServer:
return SecureInternetServer(
identifier,
display_name,
support_contact,
profiles,
- current_profile,
current_server.expire_time,
current_server.country_code.decode("utf-8"),
)
@@ -113,12 +120,21 @@ def get_server(ptr, _type=None):
display_name,
support_contact,
profiles,
- current_profile,
current_server.expire_time,
)
- return Server(
- identifier, display_name, profiles, current_profile, current_server.expire_time
- )
+ return Server(identifier, display_name, profiles, current_server.expire_time)
+
+
+def get_transition_server(ptr):
+ server = get_server(cast(ptr, POINTER(cServer)))
+ lib.FreeServer(ptr)
+ return server
+
+
+def get_transition_profiles(ptr):
+ profiles = get_profiles(cast(ptr, POINTER(cServerProfiles)))
+ lib.FreeProfiles(ptr)
+ return profiles
def get_servers(ptr):
@@ -147,6 +163,7 @@ def get_servers(ptr):
return returned
return None
+
def get_locations(ptr):
if ptr:
locations = cast(ptr, POINTER(cServerLocations)).contents