diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-26 14:50:22 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-26 15:33:04 +0200 |
| commit | 7e4494256a08f585523e01b1bbc51f41ff4e2b95 (patch) | |
| tree | ccbf873b2bfb11aa22f185e78ce1e2e5eebd094c /wrappers/python/src | |
| parent | 448c51d2142c186f0490b9d51c0d73beb3c76863 (diff) | |
Refactor: Errors into custom export types and expose types
Diffstat (limited to 'wrappers/python/src')
| -rw-r--r-- | wrappers/python/src/__init__.py | 78 | ||||
| -rw-r--r-- | wrappers/python/src/error.py | 15 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 41 |
3 files changed, 64 insertions, 70 deletions
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index 3bafc0e..cb4ba9b 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -1,11 +1,10 @@ from ctypes import * from collections import defaultdict -from enum import Enum import pathlib import platform from typing import Tuple, Optional -import json from typing import List +from .error import WrappedError, ErrorLevel _lib_prefixes = defaultdict( lambda: "lib", @@ -37,10 +36,12 @@ except: lib = cdll.LoadLibrary(str(pathlib.Path(__file__).parent / "lib" / _libfile)) -class ErrorLevel(Enum): - ERR_OTHER = 0 - ERR_INFO = 1 - +class cError(Structure): + _fields_ = [ + ("level", c_int), + ("traceback", c_char_p), + ("cause", c_char_p), + ] class cServerLocations(Structure): _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] @@ -126,7 +127,11 @@ class cServers(Structure): class DataError(Structure): - _fields_ = [("data", c_void_p), ("error", c_void_p)] + _fields_ = [("data", c_void_p), ("error", POINTER(cError))] + + +class ConfigError(Structure): + _fields_ = [("config", c_char_p), ("config_type", c_char_p), ("error", POINTER(cError))] VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) @@ -149,17 +154,17 @@ lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ c_char_p, c_char_p, c_int, -], DataError +], ConfigError lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ c_char_p, c_char_p, c_int, -], DataError +], ConfigError lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ c_char_p, c_char_p, c_int, -], DataError +], ConfigError lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], None lib.Register.argtypes, lib.Register.restype = [ c_char_p, @@ -195,19 +200,13 @@ lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ c_void_p ], None lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None +lib.FreeError.argtypes, lib.FreeError.restype = [c_void_p], None lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], DataError -class WrappedError: - def __init__(self, traceback: str, cause: str, level: ErrorLevel): - self.traceback = traceback - self.cause = cause - self.level = level - - def encode_args(args, types): for arg, t in zip(args, types): # c_char_p needs the str to be encoded to bytes @@ -239,37 +238,21 @@ def get_ptr_list_strings( return strings_list return [] - -def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]: - error_string = get_ptr_string(ptr) - - if not error_string: +def get_error(ptr: c_void_p) -> Optional[WrappedError]: + if not ptr: return None - - error_json = json.loads(error_string) - - if not error_json: - return None - - if "level" not in error_json: - return error_string - level = error_json["level"] - traceback = error_json["traceback"] - cause = error_json["cause"] - return WrappedError(traceback, cause, ErrorLevel(level)) - - -def get_error(ptr: c_void_p) -> str: - error = get_ptr_error(ptr) - if not error: - return "" - - if not isinstance(error, WrappedError): - return error - return error.cause - - -def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, str]: + err = cast(ptr, POINTER(cError)).contents + wrapped = WrappedError(err.traceback.decode(), err.cause.decode(), ErrorLevel(err.level)) + lib.FreeError(ptr) + return wrapped + +def get_config_error(config_error: ConfigError) -> Tuple[str, str, Optional[WrappedError]]: + config = get_ptr_string(config_error.config) + config_type = get_ptr_string(config_error.config_type) + err = get_error(config_error.error) + return config, config_type, err + +def get_data_error(data_error: DataError, data_conv=get_ptr_string) -> Tuple[str, Optional[WrappedError]]: data = data_conv(data_error.data) error = get_error(data_error.error) return data, error @@ -283,4 +266,5 @@ decode_map = { c_int: get_bool, c_void_p: get_error, DataError: get_data_error, + ConfigError: get_config_error, } diff --git a/wrappers/python/src/error.py b/wrappers/python/src/error.py new file mode 100644 index 0000000..50298bb --- /dev/null +++ b/wrappers/python/src/error.py @@ -0,0 +1,15 @@ +from enum import Enum + +class ErrorLevel(Enum): + ERR_OTHER = 0 + ERR_INFO = 1 + ERR_WARNING = 2 + ERR_FATAL = 3 + +class WrappedError(Exception): + def __init__(self, traceback: str, cause: str, level: ErrorLevel): + super(WrappedError, self).__init__(cause) + self.traceback = traceback + self.cause = cause + self.level = level + diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 1ee9dd7..01621ae 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -5,7 +5,6 @@ from .discovery import get_disco_organizations, get_disco_servers from .event import EventHandler from .state import State, StateType from .server import get_servers -import json eduvpn_objects = {} @@ -70,7 +69,7 @@ class EduVPN(object): cancel_oauth_err = self.go_function(lib.CancelOAuth) if cancel_oauth_err: - raise Exception(cancel_oauth_err) + raise cancel_oauth_err def deregister(self) -> None: self.go_function(lib.Deregister) @@ -85,7 +84,7 @@ class EduVPN(object): ) if register_err: - raise Exception(register_err) + raise register_err def get_disco_servers(self) -> str: servers, servers_err = self.go_function_custom_decode( @@ -93,7 +92,7 @@ class EduVPN(object): ) if servers_err: - raise Exception(servers_err) + raise servers_err return servers @@ -103,7 +102,7 @@ class EduVPN(object): ) if organizations_err: - raise Exception(organizations_err) + raise organizations_err return organizations @@ -111,19 +110,19 @@ class EduVPN(object): remove_err = self.go_function(lib.RemoveSecureInternet) if remove_err: - raise Exception(remove_err) + raise remove_err def remove_institute_access(self, url: str): remove_err = self.go_function(lib.RemoveInstituteAccess, url) if remove_err: - raise Exception(remove_err) + raise remove_err def remove_custom_server(self, url: str): remove_err = self.go_function(lib.RemoveCustomServer, url) if remove_err: - raise Exception(remove_err) + raise remove_err def get_config(self, url: str, func: callable, force_tcp: bool = False): # Because it could be the case that a profile callback is started, store a threading event @@ -131,17 +130,13 @@ class EduVPN(object): # The event is set in self.set_profile self.profile_event = threading.Event() - config_json, config_err = self.go_function(func, url, force_tcp) + config, config_type, config_err = self.go_function(func, url, force_tcp) self.profile_event = None self.location_event = None if config_err: - raise Exception(config_err) - - config_json_dict = json.loads(config_json) - config = config_json_dict["config"] - config_type = config_json_dict["config_type"] + raise config_err return config, config_type @@ -169,31 +164,31 @@ class EduVPN(object): connect_err = self.go_function(lib.SetConnected) if connect_err: - raise Exception(connect_err) + raise connect_err def set_disconnecting(self) -> None: disconnecting_err = self.go_function(lib.SetDisconnecting) if disconnecting_err: - raise Exception(disconnecting_err) + raise disconnecting_err def set_connecting(self) -> None: connecting_err = self.go_function(lib.SetConnecting) if connecting_err: - raise Exception(connecting_err) + raise connecting_err def set_disconnected(self, cleanup=True) -> None: disconnect_err = self.go_function(lib.SetDisconnected, cleanup) if disconnect_err: - raise Exception(disconnect_err) + raise disconnect_err def set_search_server(self) -> None: search_err = self.go_function(lib.SetSearchServer) if search_err: - raise Exception(search_err) + raise search_err def remove_class_callbacks(self, cls) -> None: self.event_handler.change_class_callbacks(cls, add=False) @@ -218,7 +213,7 @@ class EduVPN(object): self.profile_event.set() if profile_err: - raise Exception(profile_err) + raise profile_err def change_secure_location(self) -> None: # Set the location by country code @@ -226,7 +221,7 @@ class EduVPN(object): location_err = self.go_function(lib.ChangeSecureLocation) if location_err: - raise Exception(location_err) + raise location_err def set_secure_location(self, country_code: str) -> None: # Set the location by country code @@ -238,13 +233,13 @@ class EduVPN(object): self.location_event.set() if location_err: - raise Exception(location_err) + raise location_err def renew_session(self) -> None: renew_err = self.go_function(lib.RenewSession) if renew_err: - raise Exception(renew_err) + raise renew_err def should_renew_button(self) -> bool: return self.go_function(lib.ShouldRenewButton) |
