summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-10-14 15:22:43 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-10-14 17:54:36 +0200
commitd615578b89f1f10d0f057315a58a29c30f1f8693 (patch)
tree42565ffa9f4744f5b4afeb809f8d41d44858043b
parent13d3f167e9d13aef234bcf6f4609d9aa7d61e032 (diff)
Python: Annotate more return types
-rw-r--r--wrappers/python/eduvpn_common/discovery.py6
-rw-r--r--wrappers/python/eduvpn_common/event.py6
-rw-r--r--wrappers/python/eduvpn_common/main.py32
3 files changed, 22 insertions, 22 deletions
diff --git a/wrappers/python/eduvpn_common/discovery.py b/wrappers/python/eduvpn_common/discovery.py
index f86d4fa..be9d079 100644
--- a/wrappers/python/eduvpn_common/discovery.py
+++ b/wrappers/python/eduvpn_common/discovery.py
@@ -70,7 +70,7 @@ def get_disco_organization(ptr) -> Optional[DiscoOrganization]:
return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
-def get_disco_server(lib: CDLL, ptr):
+def get_disco_server(lib: CDLL, ptr) -> Optional[DiscoServer]:
if not ptr:
return None
@@ -101,7 +101,7 @@ def get_disco_server(lib: CDLL, ptr):
)
-def get_disco_servers(lib: CDLL, ptr: c_void_p):
+def get_disco_servers(lib: CDLL, ptr: c_void_p) -> Optional[DiscoServers]:
if ptr:
svrs = cast(ptr, POINTER(cDiscoveryServers)).contents
@@ -119,7 +119,7 @@ def get_disco_servers(lib: CDLL, ptr: c_void_p):
return None
-def get_disco_organizations(lib: CDLL, ptr: c_void_p):
+def get_disco_organizations(lib: CDLL, ptr: c_void_p) -> Optional[DiscoOrganizations]:
if ptr:
orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
organizations = []
diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py
index d2ab952..e2bfabb 100644
--- a/wrappers/python/eduvpn_common/event.py
+++ b/wrappers/python/eduvpn_common/event.py
@@ -22,7 +22,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
-def convert_data(lib: CDLL, state: int, data: Any):
+def convert_data(lib: CDLL, state: int, data: Any) -> None:
if not data:
return None
if state is State.NO_SERVER:
@@ -67,7 +67,7 @@ class EventHandler(object):
else:
self.remove_event(state, state_type, method)
- def remove_event(self, state: int, state_type: StateType, func: Callable):
+ def remove_event(self, state: int, state_type: StateType, func: Callable) -> None:
for key, values in self.handlers.copy().items():
if key == (state, state_type):
values.remove(func)
@@ -76,7 +76,7 @@ class EventHandler(object):
else:
self.handlers[key] = values
- def add_event(self, state: int, state_type: StateType, func: Callable):
+ def add_event(self, state: int, state_type: StateType, func: Callable) -> None:
if (state, state_type) not in self.handlers:
self.handlers[(state, state_type)] = []
self.handlers[(state, state_type)].append(func)
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index b581c5e..49618d8 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,11 +1,11 @@
import threading
from ctypes import c_char_p, c_int, c_void_p
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
-from eduvpn_common.discovery import get_disco_organizations, get_disco_servers
+from eduvpn_common.discovery import DiscoOrganizations, DiscoServers, get_disco_organizations, get_disco_servers
from eduvpn_common.event import EventHandler
from eduvpn_common.loader import initialize_functions, load_lib
-from eduvpn_common.server import get_servers
+from eduvpn_common.server import Server, get_servers
from eduvpn_common.state import State, StateType
from eduvpn_common.types import VPNStateChange, decode_res, encode_args, get_data_error
@@ -39,7 +39,7 @@ class EduVPN(object):
if self.location_event:
self.location_event.wait()
- def go_function(self, func: Any, *args, decode_func: Optional[Callable] = None):
+ def go_function(self, func: Any, *args, decode_func: Optional[Callable] = None) -> Any:
# The functions all have at least one arg type which is the name of the client
args_gen = encode_args(list(args), func.argtypes[1:])
res = func(self.name.encode("utf-8"), *(args_gen))
@@ -73,7 +73,7 @@ class EduVPN(object):
if register_err:
raise register_err
- def get_disco_servers(self) -> str:
+ def get_disco_servers(self) -> Optional[DiscoServers]:
servers, servers_err = self.go_function(
self.lib.GetDiscoServers,
decode_func=lambda lib, x: get_data_error(lib, x, get_disco_servers),
@@ -84,7 +84,7 @@ class EduVPN(object):
return servers
- def get_disco_organizations(self) -> str:
+ def get_disco_organizations(self) -> Optional[DiscoOrganizations]:
organizations, organizations_err = self.go_function(
self.lib.GetDiscoOrganizations,
decode_func=lambda lib, x: get_data_error(lib, x, get_disco_organizations),
@@ -95,44 +95,44 @@ class EduVPN(object):
return organizations
- def remove_secure_internet(self):
+ def remove_secure_internet(self) -> None:
remove_err = self.go_function(self.lib.RemoveSecureInternet)
if remove_err:
raise remove_err
- def add_institute_access(self, url: str):
+ def add_institute_access(self, url: str) -> None:
add_err = self.go_function(self.lib.AddInstituteAccess, url)
if add_err:
raise add_err
- def add_secure_internet_home(self, org_id: str):
+ def add_secure_internet_home(self, org_id: str) -> None:
self.location_event = threading.Event()
add_err = self.go_function(self.lib.AddSecureInternetHomeServer, org_id)
if add_err:
raise add_err
- def add_custom_server(self, url: str):
+ def add_custom_server(self, url: str) -> None:
add_err = self.go_function(self.lib.AddCustomServer, url)
if add_err:
raise add_err
- def remove_institute_access(self, url: str):
+ def remove_institute_access(self, url: str) -> None:
remove_err = self.go_function(self.lib.RemoveInstituteAccess, url)
if remove_err:
raise remove_err
- def remove_custom_server(self, url: str):
+ def remove_custom_server(self, url: str) -> None:
remove_err = self.go_function(self.lib.RemoveCustomServer, url)
if remove_err:
raise remove_err
- def get_config(self, url: str, func: Any, prefer_tcp: bool = False):
+ def get_config(self, url: str, func: Any, prefer_tcp: bool = False) -> Tuple[str, str]:
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
@@ -254,7 +254,7 @@ class EduVPN(object):
def in_fsm_state(self, state_id: State) -> bool:
return self.go_function(self.lib.InFSMState, state_id)
- def get_saved_servers(self):
+ def get_saved_servers(self) -> Optional[List[Server]]:
servers, servers_err = self.go_function(
self.lib.GetSavedServers,
decode_func=lambda lib, x: get_data_error(lib, x, get_servers),
@@ -270,7 +270,7 @@ eduvpn_objects: Dict[str, EduVPN] = {}
@VPNStateChange
-def state_callback(name: bytes, old_state: int, new_state: int, data: Any):
+def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> None:
name_decoded = name.decode()
if name_decoded not in eduvpn_objects:
return
@@ -285,6 +285,6 @@ def add_as_global_object(eduvpn: EduVPN) -> bool:
return False
-def remove_as_global_object(eduvpn: EduVPN):
+def remove_as_global_object(eduvpn: EduVPN) -> None:
global eduvpn_objects
eduvpn_objects.pop(eduvpn.name, None)