summaryrefslogtreecommitdiff
path: root/wrappers
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-26 14:50:22 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-26 15:33:04 +0200
commit7e4494256a08f585523e01b1bbc51f41ff4e2b95 (patch)
treeccbf873b2bfb11aa22f185e78ce1e2e5eebd094c /wrappers
parent448c51d2142c186f0490b9d51c0d73beb3c76863 (diff)
Refactor: Errors into custom export types and expose types
Diffstat (limited to 'wrappers')
-rw-r--r--wrappers/python/src/__init__.py78
-rw-r--r--wrappers/python/src/error.py15
-rw-r--r--wrappers/python/src/main.py41
3 files changed, 64 insertions, 70 deletions
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index 3bafc0e..cb4ba9b 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -1,11 +1,10 @@
from ctypes import *
from collections import defaultdict
-from enum import Enum
import pathlib
import platform
from typing import Tuple, Optional
-import json
from typing import List
+from .error import WrappedError, ErrorLevel
_lib_prefixes = defaultdict(
lambda: "lib",
@@ -37,10 +36,12 @@ except:
lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile))
-class ErrorLevel(Enum):
- ERR_OTHER = 0
- ERR_INFO = 1
-
+class cError(Structure):
+ _fields_ = [
+ ("level", c_int),
+ ("traceback", c_char_p),
+ ("cause", c_char_p),
+ ]
class cServerLocations(Structure):
_fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)]
@@ -126,7 +127,11 @@ class cServers(Structure):
class DataError(Structure):
- _fields_ = [("data", c_void_p), ("error", c_void_p)]
+ _fields_ = [("data", c_void_p), ("error", POINTER(cError))]
+
+
+class ConfigError(Structure):
+ _fields_ = [("config", c_char_p), ("config_type", c_char_p), ("error", POINTER(cError))]
VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
@@ -149,17 +154,17 @@ lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
c_char_p,
c_char_p,
c_int,
-], DataError
+], ConfigError
lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [
c_char_p,
c_char_p,
c_int,
-], DataError
+], ConfigError
lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [
c_char_p,
c_char_p,
c_int,
-], DataError
+], ConfigError
lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None
lib.Register.argtypes, lib.Register.restype = [
c_char_p,
@@ -195,19 +200,13 @@ lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
c_void_p
], None
lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None
+lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None
lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None
lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError
-class WrappedError:
- def __init__(self, traceback: str, cause: str, level: ErrorLevel):
- self.traceback = traceback
- self.cause = cause
- self.level = level
-
-
def encode_args(args, types):
for arg, t in zip(args, types):
# c_char_p needs the str to be encoded to bytes
@@ -239,37 +238,21 @@ def get_ptr_list_strings(
return strings_list
return []
-
-def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]:
- error_string = get_ptr_string(ptr)
-
- if not error_string:
+def get_error(ptr: c_void_p) -> Optional[WrappedError]:
+ if not ptr:
return None
-
- error_json = json.loads(error_string)
-
- if not error_json:
- return None
-
- if "level" not in error_json:
- return error_string
- level = error_json["level"]
- traceback = error_json["traceback"]
- cause = error_json["cause"]
- return WrappedError(traceback, cause, ErrorLevel(level))
-
-
-def get_error(ptr: c_void_p) -> str:
- error = get_ptr_error(ptr)
- if not error:
- return ""
-
- if not isinstance(error, WrappedError):
- return error
- return error.cause
-
-
-def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, str]:
+ err = cast(ptr, POINTER(cError)).contents
+ wrapped = WrappedError(err.traceback.decode(), err.cause.decode(), ErrorLevel(err.level))
+ lib.FreeError(ptr)
+ return wrapped
+
+def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]:
+ config = get_ptr_string(config_error.config)
+ config_type = get_ptr_string(config_error.config_type)
+ err = get_error(config_error.error)
+ return config, config_type, err
+
+def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]:
data = data_conv(data_error.data)
error = get_error(data_error.error)
return data, error
@@ -283,4 +266,5 @@ decode_map = {
c_int: get_bool,
c_void_p: get_error,
DataError: get_data_error,
+ ConfigError: get_config_error,
}
diff --git a/wrappers/python/src/error.py b/wrappers/python/src/error.py
new file mode 100644
index 0000000..50298bb
--- /dev/null
+++ b/wrappers/python/src/error.py
@@ -0,0 +1,15 @@
+from enum import Enum
+
+class ErrorLevel(Enum):
+ ERR_OTHER = 0
+ ERR_INFO = 1
+ ERR_WARNING = 2
+ ERR_FATAL = 3
+
+class WrappedError(Exception):
+ def __init__(self, traceback: str, cause: str, level: ErrorLevel):
+ super(WrappedError, self).__init__(cause)
+ self.traceback = traceback
+ self.cause = cause
+ self.level = level
+
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 1ee9dd7..01621ae 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -5,7 +5,6 @@ from .discovery import get_disco_organizations, get_disco_servers
from .event import EventHandler
from .state import State, StateType
from .server import get_servers
-import json
eduvpn_objects = {}
@@ -70,7 +69,7 @@ class EduVPN(object):
cancel_oauth_err = self.go_function(lib.CancelOAuth)
if cancel_oauth_err:
- raise Exception(cancel_oauth_err)
+ raise cancel_oauth_err
def deregister(self) -> None:
self.go_function(lib.Deregister)
@@ -85,7 +84,7 @@ class EduVPN(object):
)
if register_err:
- raise Exception(register_err)
+ raise register_err
def get_disco_servers(self) -> str:
servers, servers_err = self.go_function_custom_decode(
@@ -93,7 +92,7 @@ class EduVPN(object):
)
if servers_err:
- raise Exception(servers_err)
+ raise servers_err
return servers
@@ -103,7 +102,7 @@ class EduVPN(object):
)
if organizations_err:
- raise Exception(organizations_err)
+ raise organizations_err
return organizations
@@ -111,19 +110,19 @@ class EduVPN(object):
remove_err = self.go_function(lib.RemoveSecureInternet)
if remove_err:
- raise Exception(remove_err)
+ raise remove_err
def remove_institute_access(self, url: str):
remove_err = self.go_function(lib.RemoveInstituteAccess, url)
if remove_err:
- raise Exception(remove_err)
+ raise remove_err
def remove_custom_server(self, url: str):
remove_err = self.go_function(lib.RemoveCustomServer, url)
if remove_err:
- raise Exception(remove_err)
+ raise remove_err
def get_config(self, url: str, func: callable, force_tcp: bool = False):
# Because it could be the case that a profile callback is started, store a threading event
@@ -131,17 +130,13 @@ class EduVPN(object):
# The event is set in self.set_profile
self.profile_event = threading.Event()
- config_json, config_err = self.go_function(func, url, force_tcp)
+ config, config_type, config_err = self.go_function(func, url, force_tcp)
self.profile_event = None
self.location_event = None
if config_err:
- raise Exception(config_err)
-
- config_json_dict = json.loads(config_json)
- config = config_json_dict["config"]
- config_type = config_json_dict["config_type"]
+ raise config_err
return config, config_type
@@ -169,31 +164,31 @@ class EduVPN(object):
connect_err = self.go_function(lib.SetConnected)
if connect_err:
- raise Exception(connect_err)
+ raise connect_err
def set_disconnecting(self) -> None:
disconnecting_err = self.go_function(lib.SetDisconnecting)
if disconnecting_err:
- raise Exception(disconnecting_err)
+ raise disconnecting_err
def set_connecting(self) -> None:
connecting_err = self.go_function(lib.SetConnecting)
if connecting_err:
- raise Exception(connecting_err)
+ raise connecting_err
def set_disconnected(self, cleanup=True) -> None:
disconnect_err = self.go_function(lib.SetDisconnected, cleanup)
if disconnect_err:
- raise Exception(disconnect_err)
+ raise disconnect_err
def set_search_server(self) -> None:
search_err = self.go_function(lib.SetSearchServer)
if search_err:
- raise Exception(search_err)
+ raise search_err
def remove_class_callbacks(self, cls) -> None:
self.event_handler.change_class_callbacks(cls, add=False)
@@ -218,7 +213,7 @@ class EduVPN(object):
self.profile_event.set()
if profile_err:
- raise Exception(profile_err)
+ raise profile_err
def change_secure_location(self) -> None:
# Set the location by country code
@@ -226,7 +221,7 @@ class EduVPN(object):
location_err = self.go_function(lib.ChangeSecureLocation)
if location_err:
- raise Exception(location_err)
+ raise location_err
def set_secure_location(self, country_code: str) -> None:
# Set the location by country code
@@ -238,13 +233,13 @@ class EduVPN(object):
self.location_event.set()
if location_err:
- raise Exception(location_err)
+ raise location_err
def renew_session(self) -> None:
renew_err = self.go_function(lib.RenewSession)
if renew_err:
- raise Exception(renew_err)
+ raise renew_err
def should_renew_button(self) -> bool:
return self.go_function(lib.ShouldRenewButton)