summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--exports/exports.go26
-rw-r--r--wrappers/python/src/__init__.py37
-rw-r--r--wrappers/python/src/event.py86
-rw-r--r--wrappers/python/src/main.py256
4 files changed, 180 insertions, 225 deletions
diff --git a/exports/exports.go b/exports/exports.go
index 1081754..567e189 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -3,12 +3,12 @@ package main
/*
#include <stdlib.h>
-typedef void (*PythonCB)(const char* oldstate, const char* newstate, const char* data);
+typedef void (*PythonCB)(const char* name, const char* oldstate, const char* newstate, const char* data);
__attribute__((weak))
-void call_callback(PythonCB callback, const char* oldstate, const char* newstate, const char* data)
+void call_callback(PythonCB callback, const char *name, const char* oldstate, const char* newstate, const char* data)
{
- callback(oldstate, newstate, data);
+ callback(name, oldstate, newstate, data);
}
*/
import "C"
@@ -21,18 +21,21 @@ import (
"github.com/jwijenbergh/eduvpn-common"
)
-var P_StateCallback C.PythonCB
+var P_StateCallbacks map[string]C.PythonCB
var VPNStates map[string]*eduvpn.VPNState
-func StateCallback(old_state string, new_state string, data string) {
- if P_StateCallback == nil {
+func StateCallback(name string, old_state string, new_state string, data string) {
+ P_StateCallback, exists := P_StateCallbacks[name]
+ if !exists || P_StateCallback == nil {
return
}
+ name_c := C.CString(name)
oldState_c := C.CString(old_state)
newState_c := C.CString(new_state)
data_c := C.CString(data)
- C.call_callback(P_StateCallback, oldState_c, newState_c, data_c)
+ C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c)
+ C.free(unsafe.Pointer(name_c))
C.free(unsafe.Pointer(oldState_c))
C.free(unsafe.Pointer(newState_c))
C.free(unsafe.Pointer(data_c))
@@ -58,9 +61,14 @@ func Register(name *C.char, config_directory *C.char, stateCallback C.PythonCB,
if VPNStates == nil {
VPNStates = make(map[string]*eduvpn.VPNState)
}
+ if P_StateCallbacks == nil {
+ P_StateCallbacks = make(map[string]C.PythonCB)
+ }
VPNStates[nameStr] = state
- P_StateCallback = stateCallback
- registerErr := state.Register(nameStr, C.GoString(config_directory), StateCallback, debug != 0)
+ P_StateCallbacks[nameStr] = stateCallback
+ registerErr := state.Register(nameStr, C.GoString(config_directory), func(old string, new string, data string) {
+ StateCallback(nameStr, old, new, data)
+ }, debug != 0)
if registerErr != nil {
delete(VPNStates, nameStr)
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index d260916..8129495 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -35,7 +35,7 @@ class MultipleDataError(Structure):
_fields_ = [("data", c_void_p), ("other_data", c_void_p), ("error", c_void_p)]
-VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p)
+VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p, c_char_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
@@ -67,7 +67,19 @@ lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
-def GetPtrString(ptr: c_void_p) -> str:
+def encode_args(args, types):
+ for arg, t in zip(args, types):
+ # c_char_p needs the str to be encoded to bytes
+ if t is c_char_p:
+ arg = arg.encode("utf-8")
+ yield arg
+
+
+def decode_res(t):
+ return decode_map.get(t, lambda x: x)
+
+
+def get_ptr_string(ptr: c_void_p) -> str:
if ptr:
string = cast(ptr, c_char_p).value
lib.FreeString(ptr)
@@ -76,16 +88,23 @@ def GetPtrString(ptr: c_void_p) -> str:
return ""
-def GetDataError(data_error: DataError) -> Tuple[str, str]:
- data = GetPtrString(data_error.data)
- error = GetPtrString(data_error.error)
+def get_data_error(data_error: DataError) -> Tuple[str, str]:
+ data = get_ptr_string(data_error.data)
+ error = get_ptr_string(data_error.error)
return data, error
-def GetMultipleDataError(
+def get_multiple_data_error(
multiple_data_error: MultipleDataError,
) -> Tuple[str, str, str]:
- data = GetPtrString(multiple_data_error.data)
- other_data = GetPtrString(multiple_data_error.other_data)
- error = GetPtrString(multiple_data_error.error)
+ data = get_ptr_string(multiple_data_error.data)
+ other_data = get_ptr_string(multiple_data_error.other_data)
+ error = get_ptr_string(multiple_data_error.error)
return data, other_data, error
+
+
+decode_map = {
+ c_void_p: get_ptr_string,
+ DataError: get_data_error,
+ MultipleDataError: get_multiple_data_error,
+}
diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py
new file mode 100644
index 0000000..778ce5e
--- /dev/null
+++ b/wrappers/python/src/event.py
@@ -0,0 +1,86 @@
+from . import VPNStateChange
+from enum import Enum
+from typing import Callable
+
+
+class StateType(Enum):
+ Enter = 1
+ Leave = 2
+ Wait = 3
+
+
+EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
+
+# A state transition decorator for classes
+# To use this, make sure to register the class with `register_class_callbacks`
+def class_state_transition(state: str, state_type: StateType) -> Callable:
+ def wrapper(func):
+ setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type))
+ return func
+
+ return wrapper
+
+
+class EventHandler(object):
+ def __init__(self):
+ self.handlers = {}
+
+ def change_class_callbacks(self, cls, add=True) -> None:
+ # Loop over method names
+ for method_name in dir(cls):
+ try:
+ # Get the method
+ method = getattr(cls, method_name)
+ except:
+ # Unable to get a value, go to the next
+ continue
+
+ # If it has a callback defined, add it to the events
+ method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None)
+ if method_value:
+ state, state_type = method_value
+
+ if add:
+ self.add_event(state, state_type, method)
+ else:
+ self.remove_event(state, state_type, method)
+
+ def remove_event(self, state: str, state_type: StateType, func: Callable):
+ for key, values in self.handlers.copy().items():
+ if key == (state, state_type):
+ values.remove(func)
+ if not values:
+ del self.handlers[key]
+ else:
+ self.handlers[key] = values
+
+ def add_event(self, state: str, state_type: StateType, func: Callable):
+ if (state, state_type) not in self.handlers:
+ self.handlers[(state, state_type)] = []
+ self.handlers[(state, state_type)].append(func)
+
+ # A decorator for standalone functions
+ def on(self, state: str, state_type: StateType) -> Callable:
+ def wrapped_f(func):
+ self.add_event(state, state_type, func)
+ return func
+
+ return wrapped_f
+
+ def run_state(
+ self, state: str, other_state: str, state_type: StateType, data: str
+ ) -> None:
+ if (state, state_type) not in self.handlers:
+ return
+ for func in self.handlers[(state, state_type)]:
+ func(other_state, data)
+
+ def run(self, old_state: str, new_state: str, data: str) -> None:
+ if old_state == new_state:
+ return
+
+ # First run leave transitions, then enter
+ # The state is done when the wait event finishes
+ self.run_state(old_state, new_state, StateType.Leave, data)
+ self.run_state(new_state, old_state, StateType.Enter, data)
+ self.run_state(new_state, old_state, StateType.Wait, data)
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 03757df..76a08ab 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,178 +1,29 @@
-from . import lib, VPNStateChange, GetDataError, GetMultipleDataError, GetPtrString
-from ctypes import *
-from enum import Enum
-from typing import Callable, Optional, Tuple
-from functools import wraps
+from . import lib, VPNStateChange, encode_args, decode_res
+from typing import Optional, Tuple
import threading
+from .event import StateType, EventHandler
+eduvpn_objects = {}
-class StateType(Enum):
- Enter = 1
- Leave = 2
- Wait = 3
-
-EDUVPN_CALLBACK_PROPERTY = '_eduvpn_property_callback'
-
-# A state transition decorator for classes
-# To use this, make sure to register the class with `register_class_callbacks`
-def class_state_transition(state: str, state_type: StateType) -> Callable:
- def wrapper(func):
- setattr(func, EDUVPN_CALLBACK_PROPERTY, (state, state_type))
- return func
- return wrapper
-
-class EventHandler(object):
- def __init__(self):
- self.handlers = {}
-
- def remove_event(self, state: str, state_type: StateType, func: Callable):
- for key, values in self.handlers.copy().items():
- if key == (state, state_type):
- values.remove(func)
- if not values:
- del self.handlers[key]
- else:
- self.handlers[key] = values
-
- def add_event(self, state: str, state_type: StateType, func: Callable):
- if (state, state_type) not in self.handlers:
- self.handlers[(state, state_type)] = []
- self.handlers[(state, state_type)].append(func)
-
- # A decorator for standalone functions
- def on(self, state: str, state_type: StateType) -> Callable:
- def wrapped_f(func):
- self.add_event(state, state_type, func)
- return func
-
- return wrapped_f
-
- def run_state(
- self, state: str, other_state: str, state_type: StateType, data: str
- ) -> None:
- if (state, state_type) not in self.handlers:
- return
- for func in self.handlers[(state, state_type)]:
- func(other_state, data)
-
- def run(self, old_state: str, new_state: str, data: str) -> None:
- if old_state == new_state:
- return
-
- # First run leave transitions, then enter
- # The state is done when the wait event finishes
- self.run_state(old_state, new_state, StateType.Leave, data)
- self.run_state(new_state, old_state, StateType.Enter, data)
- self.run_state(new_state, old_state, StateType.Wait, data)
-
-
-# Registers the python app with the Go code
-# name: The name of the app to be registered
-# state_callback: The callback to trigger whenever a state is changed
-def Register(
- name: str, config_directory: str, state_callback: Optional[Callable], debug: bool
-) -> str:
- if not state_callback:
- return "No callback provided"
- name_bytes = name.encode("utf-8")
- dir_bytes = config_directory.encode("utf-8")
- ptr_err = lib.Register(name_bytes, dir_bytes, state_callback, debug)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def CancelOAuth(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.CancelOAuth(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def Deregister(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.Deregister(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def GetDiscoServers(name: str) -> Tuple[str, str]:
- name_bytes = name.encode("utf-8")
- servers, servers_err = GetDataError(lib.GetServersList(name_bytes))
- return servers, servers_err
-
-
-def GetDiscoOrganizations(name: str) -> Tuple[str, str]:
- name_bytes = name.encode("utf-8")
- organizations, organizations_err = GetDataError(
- lib.GetOrganizationsList(name_bytes)
- )
- return organizations, organizations_err
-
-
-def GetConnectConfig(
- name: str, url: str, is_secure_internet: bool, force_tcp: bool
-) -> Tuple[str, str, str]:
- name_bytes = name.encode("utf-8")
- url_bytes = url.encode("utf-8")
- multiple_data_error = lib.GetConnectConfig(
- name_bytes, url_bytes, is_secure_internet, force_tcp
- )
- return GetMultipleDataError(multiple_data_error)
-
-
-def SetConnected(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.SetConnected(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def SetDisconnected(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.SetDisconnected(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
+def add_as_global_object(eduvpn) -> bool:
+ global eduvpn_objects
+ if eduvpn.name not in eduvpn_objects:
+ eduvpn_objects[eduvpn.name] = eduvpn
+ return True
+ return False
-def SetSearchServer(name: str) -> str:
- name_bytes = name.encode("utf-8")
- ptr_err = lib.SetSearchServer(name_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def SetIdentifier(name: str, identifier: str) -> str:
- name_bytes = name.encode("utf-8")
- identifier_bytes = identifier.encode("utf-8")
- ptr_err = lib.SetIdentifier(name_bytes, identifier_bytes)
- err_string = GetPtrString(ptr_err)
- return err_string
-
-
-def GetIdentifier(name: str) -> Tuple[str, str]:
- name_bytes = name.encode("utf-8")
- identifier, identifier_err = GetDataError(lib.GetIdentifier(name_bytes))
- return identifier, identifier_err
+def remove_as_global_object(eduvpn):
+ global eduvpn_objects
+ eduvpn_objects.pop(eduvpn.name, None)
-# This has to be global as otherwise the callback is not alive
-callback_function = None
-
-
-def register_callback(eduvpn):
- global callback_function
- callback_function = VPNStateChange(
- lambda old_state, new_state, data: eduvpn.callback(
- old_state.decode(), new_state.decode(), data.decode()
- )
- )
-
-
-def SetProfileID(name: str, profile_id: str) -> str:
- name_bytes = name.encode("utf-8")
- profile_bytes = profile_id.encode("utf-8")
- error_string = lib.SetProfileID(name_bytes, profile_bytes)
- return GetPtrString(error_string)
+@VPNStateChange
+def state_callback(name, old_state, new_state, data):
+ name = name.decode()
+ if name not in eduvpn_objects:
+ return
+ eduvpn_objects[name].callback(old_state.decode(), new_state.decode(), data.decode())
class EduVPN(object):
@@ -180,40 +31,50 @@ class EduVPN(object):
self.event_handler = EventHandler()
self.name = name
self.config_directory = config_directory
- register_callback(self)
# Callbacks that need to wait for specific events
# The ask profile callback needs to wait for the UI thread to select a profile
# This is stored in the profile_event
self.profile_event: Optional[threading.Event] = None
+
@self.event.on("Ask_Profile", StateType.Wait)
def wait_profile_event(old_state: str, profiles: str):
if self.profile_event:
self.profile_event.wait()
+ def go_function(self, func, *args):
+ # The functions all have at least one arg type which is the name of the client
+ args_gen = encode_args(list(args), func.argtypes[1:])
+ res = func(self.name.encode("utf-8"), *(args_gen))
+ return decode_res(func.restype)(res)
+
def cancel_oauth(self) -> None:
- cancel_oauth_err = CancelOAuth(self.name)
+ cancel_oauth_err = self.go_function(lib.CancelOAuth)
if cancel_oauth_err:
raise Exception(cancel_oauth_err)
def deregister(self) -> None:
- deregister_err = Deregister(self.name)
+ deregister_err = self.go_function(lib.Deregister)
+ remove_as_global_object(self)
if deregister_err:
raise Exception(deregister_err)
def register(self, debug: bool = False) -> None:
- register_err = Register(
- self.name, self.config_directory, callback_function, debug
+ if not add_as_global_object(self):
+ raise Exception("Already registered")
+
+ register_err = self.go_function(
+ lib.Register, self.config_directory, state_callback, debug
)
if register_err:
raise Exception(register_err)
def get_disco_servers(self) -> str:
- servers, servers_err = GetDiscoServers(self.name)
+ servers, servers_err = self.go_function(lib.GetDiscoServers)
if servers_err:
raise Exception(servers_err)
@@ -221,20 +82,22 @@ class EduVPN(object):
return servers
def get_disco_organizations(self) -> str:
- organizations, organizations_err = GetDiscoOrganizations(self.name)
+ organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations)
if organizations_err:
raise Exception(organizations_err)
return organizations
- def get_config(self, url: str, is_secure_internet: bool = False, force_tcp: bool = False):
+ def get_config(
+ self, url: str, is_secure_internet: bool = False, force_tcp: bool = False
+ ):
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
self.profile_event = threading.Event()
- config, config_type, config_err = GetConnectConfig(
- self.name, url, is_secure_internet, force_tcp
+ config, config_type, config_err = self.go_function(
+ lib.GetConnectConfig, url, is_secure_internet, force_tcp
)
if config_err:
@@ -255,19 +118,19 @@ class EduVPN(object):
return self.get_config(url, True, force_tcp)
def set_connected(self) -> None:
- connect_err = SetConnected(self.name)
+ connect_err = self.go_function(lib.SetConnected)
if connect_err:
raise Exception(connect_err)
def set_disconnected(self) -> None:
- disconnect_err = SetDisconnected(self.name)
+ disconnect_err = self.go_function(lib.SetDisconnected)
if disconnect_err:
raise Exception(disconnect_err)
def get_identifier(self) -> str:
- identifier, identifier_err = GetIdentifier(self.name)
+ identifier, identifier_err = self.go_function(lib.GetIdentifier)
if identifier_err:
raise Exception(identifier_err)
@@ -275,43 +138,22 @@ class EduVPN(object):
return identifier
def set_identifier(self, identifier: str) -> None:
- identifier_err = SetIdentifier(self.name, identifier)
+ identifier_err = self.go_function(lib.SetIdentifier, identifier)
if identifier_err:
raise Exception(identifier_err)
def set_search_server(self) -> None:
- search_err = SetSearchServer(self.name)
+ search_err = self.go_function(lib.SetSearchServer)
if search_err:
raise Exception(search_err)
- def change_class_callbacks(self, cls, add=True) -> None:
- # Loop over method names
- for method_name in dir(cls):
-
- try:
- # Get the method
- method = getattr(cls, method_name)
- except:
- # Unable to get a value, go to the next
- continue
-
- # If it has a callback defined, add it to the events
- method_value = getattr(method, EDUVPN_CALLBACK_PROPERTY, None)
- if method_value:
- state, state_type = method_value
-
- if add:
- self.event.add_event(state, state_type, method)
- else:
- self.event.remove_event(state, state_type, method)
-
def remove_class_callbacks(self, cls) -> None:
- self.change_class_callbacks(cls, add=False)
+ self.event_handler.change_class_callbacks(cls, add=False)
def register_class_callbacks(self, cls) -> None:
- self.change_class_callbacks(cls)
+ self.event_handler.change_class_callbacks(cls)
@property
def event(self) -> EventHandler:
@@ -322,7 +164,7 @@ class EduVPN(object):
def set_profile(self, profile_id: str) -> None:
# Set the profile id
- profile_err = SetProfileID(self.name, profile_id)
+ profile_err = self.go_function(lib.SetProfileID, profile_id)
if profile_err:
raise Exception(profile_err)