summaryrefslogtreecommitdiff
path: root/wrappers/python/eduvpn_common/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'wrappers/python/eduvpn_common/main.py')
-rw-r--r--wrappers/python/eduvpn_common/main.py81
1 files changed, 40 insertions, 41 deletions
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index 382f356..1271fb2 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,4 +1,5 @@
-from typing import Optional, Tuple
+from ctypes import c_char_p, c_int, c_void_p
+from typing import Any, Callable, Dict, Optional, Tuple
import threading
from eduvpn_common.discovery import get_disco_organizations, get_disco_servers
from eduvpn_common.event import EventHandler
@@ -7,30 +8,6 @@ from eduvpn_common.types import VPNStateChange, encode_args, decode_res, get_dat
from eduvpn_common.server import get_servers
from eduvpn_common.state import State, StateType
-eduvpn_objects = {}
-
-
-def add_as_global_object(eduvpn) -> bool:
- global eduvpn_objects
- if eduvpn.name not in eduvpn_objects:
- eduvpn_objects[eduvpn.name] = eduvpn
- return True
- return False
-
-
-def remove_as_global_object(eduvpn):
- global eduvpn_objects
- eduvpn_objects.pop(eduvpn.name, None)
-
-
-@VPNStateChange
-def state_callback(name, old_state, new_state, data):
- name = name.decode()
- if name not in eduvpn_objects:
- return
- eduvpn_objects[name].callback(State(old_state), State(new_state), data)
-
-
class EduVPN(object):
def __init__(self, name: str, config_directory: str, language: str):
self.name = name
@@ -60,17 +37,14 @@ class EduVPN(object):
if self.location_event:
self.location_event.wait()
- def go_function(self, func, *args):
+ def go_function(self, func: Any, *args, decode_func: Optional[Callable] = None):
# The functions all have at least one arg type which is the name of the client
args_gen = encode_args(list(args), func.argtypes[1:])
res = func(self.name.encode("utf-8"), *(args_gen))
- return decode_res(func.restype)(self.lib, res)
-
- def go_function_custom_decode(self, func, decode_func, *args):
- # The functions all have at least one arg type which is the name of the client
- args_gen = encode_args(list(args), func.argtypes[1:])
- res = func(self.name.encode("utf-8"), *(args_gen))
- return decode_func(self.lib, res)
+ if decode_func is None:
+ return decode_res(func.restype)(self.lib, res)
+ else:
+ return decode_func(self.lib, res)
def cancel_oauth(self) -> None:
cancel_oauth_err = self.go_function(self.lib.CancelOAuth)
@@ -94,7 +68,7 @@ class EduVPN(object):
raise register_err
def get_disco_servers(self) -> str:
- servers, servers_err = self.go_function_custom_decode(
+ servers, servers_err = self.go_function(
self.lib.GetDiscoServers,
decode_func=lambda lib, x: get_data_error(lib, x, get_disco_servers),
)
@@ -105,7 +79,7 @@ class EduVPN(object):
return servers
def get_disco_organizations(self) -> str:
- organizations, organizations_err = self.go_function_custom_decode(
+ organizations, organizations_err = self.go_function(
self.lib.GetDiscoOrganizations,
decode_func=lambda lib, x: get_data_error(lib, x, get_disco_organizations),
)
@@ -152,7 +126,7 @@ class EduVPN(object):
if remove_err:
raise remove_err
- def get_config(self, url: str, func: callable, prefer_tcp: bool = False):
+ def get_config(self, url: str, func: Any, prefer_tcp: bool = False):
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
@@ -205,7 +179,7 @@ class EduVPN(object):
if connecting_err:
raise connecting_err
- def set_disconnected(self, cleanup=True) -> None:
+ def set_disconnected(self, cleanup: bool = True) -> None:
disconnect_err = self.go_function(self.lib.SetDisconnected, cleanup)
if disconnect_err:
@@ -217,17 +191,17 @@ class EduVPN(object):
if search_err:
raise search_err
- def remove_class_callbacks(self, cls) -> None:
+ def remove_class_callbacks(self, cls: Any) -> None:
self.event_handler.change_class_callbacks(cls, add=False)
- def register_class_callbacks(self, cls) -> None:
+ def register_class_callbacks(self, cls: Any) -> None:
self.event_handler.change_class_callbacks(cls)
@property
def event(self) -> EventHandler:
return self.event_handler
- def callback(self, old_state: State, new_state: State, data) -> None:
+ def callback(self, old_state: State, new_state: State, data: Any) -> None:
self.event.run(old_state, new_state, data)
def set_profile(self, profile_id: str) -> None:
@@ -275,7 +249,7 @@ class EduVPN(object):
return self.go_function(self.lib.InFSMState, state_id)
def get_saved_servers(self):
- servers, servers_err = self.go_function_custom_decode(
+ servers, servers_err = self.go_function(
self.lib.GetSavedServers,
decode_func=lambda lib, x: get_data_error(lib, x, get_servers),
)
@@ -284,3 +258,28 @@ class EduVPN(object):
raise servers_err
return servers
+
+eduvpn_objects: Dict[str, EduVPN] = {}
+
+
+@VPNStateChange
+def state_callback(name: bytes, old_state: int, new_state: int, data: Any):
+ name_decoded = name.decode()
+ if name_decoded not in eduvpn_objects:
+ return
+ eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data)
+
+
+
+def add_as_global_object(eduvpn: EduVPN) -> bool:
+ global eduvpn_objects
+ if eduvpn.name not in eduvpn_objects:
+ eduvpn_objects[eduvpn.name] = eduvpn
+ return True
+ return False
+
+
+def remove_as_global_object(eduvpn: EduVPN):
+ global eduvpn_objects
+ eduvpn_objects.pop(eduvpn.name, None)
+