summaryrefslogtreecommitdiff
path: root/wrappers/python/eduvpn_common/server.py
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-10-14 15:06:36 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-10-14 15:16:02 +0200
commitef224e5dad25debf4526f2fb017bd771a60448dd (patch)
tree63882cce7a4610f0d8f9b94509633a764cbbddc3 /wrappers/python/eduvpn_common/server.py
parente26dd96631022974223f0f4fba48dc95e036d63d (diff)
Python: Proper type annotations
Diffstat (limited to 'wrappers/python/eduvpn_common/server.py')
-rw-r--r--wrappers/python/eduvpn_common/server.py52
1 files changed, 27 insertions, 25 deletions
diff --git a/wrappers/python/eduvpn_common/server.py b/wrappers/python/eduvpn_common/server.py
index 71b6487..01b5204 100644
--- a/wrappers/python/eduvpn_common/server.py
+++ b/wrappers/python/eduvpn_common/server.py
@@ -1,10 +1,11 @@
+from typing import List, Optional, Type
from eduvpn_common.types import cServer, cServers, cServerLocations, cServerProfiles
-from ctypes import cast, POINTER
+from ctypes import c_void_p, cast, POINTER, CDLL
from datetime import datetime
class Profile:
- def __init__(self, identifier, display_name, default_gateway: bool):
+ def __init__(self, identifier: str, display_name: str, default_gateway: bool):
self.identifier = identifier
self.display_name = display_name
self.default_gateway = default_gateway
@@ -14,19 +15,19 @@ class Profile:
class Profiles:
- def __init__(self, profiles, current):
+ def __init__(self, profiles: List[Profile], current: int):
self.profiles = profiles
self.current_index = current
@property
- def current(self):
+ def current(self) -> Optional[Profile]:
if self.current_index < len(self.profiles):
return self.profiles[self.current_index]
return None
class Server:
- def __init__(self, url, display_name, profiles=None, expire_time=0):
+ def __init__(self, url: str, display_name: str, profiles: Optional[Profiles] = None, expire_time: int = 0):
self.url = url
self.display_name = display_name
self.profiles = profiles
@@ -36,29 +37,28 @@ class Server:
return self.display_name
@property
- def category(self):
+ def category(self) -> str:
return "Custom Server"
class InstituteServer(Server):
- def __init__(self, url, display_name, support_contact, profiles, expire_time):
+ def __init__(self, url: str, display_name: str, support_contact: List[str], profiles: Profiles, expire_time: int):
super().__init__(url, display_name, profiles, expire_time)
self.support_contact = support_contact
@property
- def category(self):
+ def category(self) -> str:
return "Institute Access Server"
-
class SecureInternetServer(Server):
def __init__(
self,
- org_id,
- display_name,
- support_contact,
- profiles,
- expire_time,
- country_code,
+ org_id: str,
+ display_name: str,
+ support_contact: List[str],
+ profiles: Profiles,
+ expire_time: int,
+ country_code: str,
):
super().__init__(org_id, display_name, profiles, expire_time)
self.org_id = org_id
@@ -66,11 +66,11 @@ class SecureInternetServer(Server):
self.country_code = country_code
@property
- def category(self):
+ def category(self) -> str:
return "Secure Internet Server"
-def get_type_for_str(type_str: str):
+def get_type_for_str(type_str: str) -> Type[Server]:
if type_str == "secure_internet":
return SecureInternetServer
if type_str == "custom_server":
@@ -78,14 +78,14 @@ def get_type_for_str(type_str: str):
return InstituteServer
-def get_profiles(ptr):
+def get_profiles(ptr) -> Optional[Profiles]:
if not ptr:
- return []
+ return None
profiles = []
_profiles = ptr.contents
current_profile = _profiles.current
if not _profiles.profiles:
- return []
+ return None
for i in range(_profiles.total_profiles):
if not _profiles.profiles[i]:
continue
@@ -100,7 +100,7 @@ def get_profiles(ptr):
return Profiles(profiles, current_profile)
-def get_server(ptr, _type=None):
+def get_server(ptr, _type=None) -> Optional[Server]:
if not ptr:
return None
@@ -116,6 +116,8 @@ def get_server(ptr, _type=None):
for i in range(current_server.total_support_contact):
support_contact.append(current_server.support_contact[i].decode("utf-8"))
profiles = get_profiles(current_server.profiles)
+ if profiles is None:
+ return None
if _type is SecureInternetServer:
return SecureInternetServer(
identifier,
@@ -136,19 +138,19 @@ def get_server(ptr, _type=None):
return Server(identifier, display_name, profiles, current_server.expire_time)
-def get_transition_server(lib, ptr):
+def get_transition_server(lib: CDLL, ptr: c_void_p) -> Optional[Server]:
server = get_server(cast(ptr, POINTER(cServer)))
lib.FreeServer(ptr)
return server
-def get_transition_profiles(lib, ptr):
+def get_transition_profiles(lib: CDLL, ptr: c_void_p) -> Optional[Profiles]:
profiles = get_profiles(cast(ptr, POINTER(cServerProfiles)))
lib.FreeProfiles(ptr)
return profiles
-def get_servers(lib, ptr):
+def get_servers(lib: CDLL, ptr: c_void_p) -> Optional[List[Server]]:
if ptr:
returned = []
servers = cast(ptr, POINTER(cServers)).contents
@@ -175,7 +177,7 @@ def get_servers(lib, ptr):
return None
-def get_locations(lib, ptr):
+def get_locations(lib: CDLL, ptr: c_void_p) -> Optional[List[str]]:
if ptr:
locations = cast(ptr, POINTER(cServerLocations)).contents
location_list = []