From da83f54606c9c1d2786d87074ee17ed972d2e1b2 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Wed, 14 Sep 2022 13:56:49 +0200 Subject: Refactor: Return without json --- exports/Makefile | 12 +- exports/c/common.c | 6 + exports/c/common.h | 3 + exports/c/disco.h | 15 ++ exports/c/servers.h | 43 ++++++ exports/common.mk | 2 + exports/disco.go | 97 +++++++++++++ exports/exports.go | 66 ++++----- exports/servers.go | 273 ++++++++++++++++++++++++++++++++++++ fsm.go | 32 +++-- go.mod | 6 +- go.sum | 1 - internal/discovery/discovery.go | 18 ++- internal/fsm/fsm.go | 11 +- internal/oauth/oauth.go | 12 +- internal/server/custom.go | 1 - internal/types/server.go | 28 ++-- internal/util/util.go | 49 +++++++ internal/util/util_test.go | 52 ++++++- state.go | 296 +++++++++++++++++++++++++++++++-------- wrappers/python/main.py | 57 ++++---- wrappers/python/src/__init__.py | 68 ++++++++- wrappers/python/src/discovery.py | 43 ++++++ wrappers/python/src/event.py | 21 ++- wrappers/python/src/main.py | 25 +++- wrappers/python/src/server.py | 158 +++++++++++++++++++++ wrappers/python/tests.py | 2 +- 27 files changed, 1225 insertions(+), 172 deletions(-) create mode 100644 exports/c/common.c create mode 100644 exports/c/common.h create mode 100644 exports/c/disco.h create mode 100644 exports/c/servers.h create mode 100644 exports/disco.go create mode 100644 exports/servers.go create mode 100644 wrappers/python/src/discovery.py create mode 100644 wrappers/python/src/server.py diff --git a/exports/Makefile b/exports/Makefile index b833228..46d17a9 100644 --- a/exports/Makefile +++ b/exports/Makefile @@ -2,6 +2,8 @@ include common.mk +CLIBPATH=./c + ifeq ($(LIB_SUFFIX),.so) # Add SONAME as cgo does not currently do this. Mostly for Android, see https://stackoverflow.com/a/48291044 export override CGO_LDFLAGS += -Wl,-soname,$(LIB_FILE) @@ -13,12 +15,18 @@ ifdef COPY_LIB_TO install $< -Dt $(COPY_LIB_TO) endif +${CLIBPATH}/libcommon$(LIB_SUFFIX): ${CLIBPATH}/common.c + $(CC) -c -Wall -Werror -fpic -o ${CLIBPATH}/common.o ${CLIBPATH}/common.c + $(CC) -shared -o $@ ${CLIBPATH}/common.o + # Build shared library and remove lib prefix (if any) from header name # GOOS and GOARCH envvars are set by common.mk # This extra target prevents unnecessary rebuild -lib/$(GOOS)/$(GOARCH)/$(LIB_FILE): exports.go .. - CGO_ENABLED=1 go build -o $@ -buildmode=c-shared $< +lib/$(GOOS)/$(GOARCH)/$(LIB_FILE): ${CLIBPATH}/libcommon$(LIB_SUFFIX) exports.go servers.go .. + CGO_ENABLED=1 go build -o $@ -buildmode=c-shared . mv lib/$(GOOS)/$(GOARCH)/$(LIB_PREFIX)$(LIB_NAME).h lib/$(GOOS)/$(GOARCH)/$(LIB_NAME).h || true # Normalize header name clean: rm -rf ../exports/lib/* + rm -rf ${CLIBPATH}/common.o + rm -rf ${CLIBPATH}/libcommon.so diff --git a/exports/c/common.c b/exports/c/common.c new file mode 100644 index 0000000..425a459 --- /dev/null +++ b/exports/c/common.c @@ -0,0 +1,6 @@ +#include "common.h" + +void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) +{ + callback(name, oldstate, newstate, data); +} diff --git a/exports/c/common.h b/exports/c/common.h new file mode 100644 index 0000000..068ad4c --- /dev/null +++ b/exports/c/common.h @@ -0,0 +1,3 @@ +typedef void (*PythonCB)(const char* name, int oldstate, int newstate, void* data); + +void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data); diff --git a/exports/c/disco.h b/exports/c/disco.h new file mode 100644 index 0000000..41d59fa --- /dev/null +++ b/exports/c/disco.h @@ -0,0 +1,15 @@ +// for size_t +#include + +typedef struct discoveryOrganization { + const char* display_name; + const char* org_id; + const char* secure_internet_home; + const char* keyword_list; +} discoveryOrganization; + +typedef struct discoveryOrganizations { + unsigned long long int version; + discoveryOrganization** organizations; + size_t total_organizations; +} discoveryOrganizations; diff --git a/exports/c/servers.h b/exports/c/servers.h new file mode 100644 index 0000000..39e52a2 --- /dev/null +++ b/exports/c/servers.h @@ -0,0 +1,43 @@ +// for size_t +#include + +// The struct for a single server profile +typedef struct serverProfile { + const char* id; + const char* display_name; + //const char* proto_list; + int default_gateway; +} serverProfile; + +// The struct for all server profiles +typedef struct serverProfiles { + int current; + serverProfile** profiles; + size_t total_profiles; +} serverProfiles; + +// The struct for server locations +typedef struct serverLocations { + const char** locations; + size_t total_locations; +} serverLocations; + +// The struct for a single server +typedef struct server { + const char* identifier; + const char* display_name; + const char* country_code; + const char** support_contact; + size_t total_support_contact; + serverProfiles* profiles; + unsigned long long int expire_time; +} server; + +// The struct for all servers +typedef struct servers { + server** custom_servers; + size_t total_custom; + server** institute_servers; + size_t total_institute; + server* secure_internet_server; +} servers; diff --git a/exports/common.mk b/exports/common.mk index c1e2a7e..211d460 100644 --- a/exports/common.mk +++ b/exports/common.mk @@ -7,6 +7,8 @@ ifndef GOARCH export GOARCH := $(shell go env GOHOSTARCH) endif +CC = gcc + ifeq (windows,$(GOOS)) LIB_PREFIX ?= LIB_SUFFIX ?= .dll diff --git a/exports/disco.go b/exports/disco.go new file mode 100644 index 0000000..9ee2af9 --- /dev/null +++ b/exports/disco.go @@ -0,0 +1,97 @@ +package main + +/* +// for free +#include +#include "c/disco.h" +*/ +import "C" + +import ( + "unsafe" + + "github.com/jwijenbergh/eduvpn-common" + "github.com/jwijenbergh/eduvpn-common/internal/types" +) + +func getCPtrDiscoOrganization( + state *eduvpn.VPNState, + organization *types.DiscoveryOrganization, +) *C.discoveryOrganization { + returnedStruct := (*C.discoveryOrganization)( + C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganization{}))), + ) + returnedStruct.display_name = C.CString(state.GetTranslated(organization.DisplayName)) + returnedStruct.org_id = C.CString(organization.OrgId) + returnedStruct.secure_internet_home = C.CString(organization.SecureInternetHome) + returnedStruct.keyword_list = C.CString(state.GetTranslated(organization.KeywordList)) + return returnedStruct +} + +func getCPtrDiscoOrganizations( + state *eduvpn.VPNState, + organizations *types.DiscoveryOrganizations, +) (C.size_t, **C.discoveryOrganization) { + totalOrganizations := C.size_t(len(organizations.List)) + var organizationsPtr **C.discoveryOrganization + if totalOrganizations > 0 { + organizationsPtr = (**C.discoveryOrganization)( + C.malloc(totalOrganizations * C.size_t(unsafe.Sizeof(uintptr(0)))), + ) + cOrganizations := (*[1<<30 - 1]*C.discoveryOrganization)(unsafe.Pointer(organizationsPtr))[:totalOrganizations:totalOrganizations] + index := 0 + for _, organization := range organizations.List { + cOrganization := getCPtrDiscoOrganization(state, &organization) + cOrganizations[index] = cOrganization + index += 1 + } + } + return totalOrganizations, organizationsPtr +} + +func freeDiscoOrganization(cOrganization *C.discoveryOrganization) { + C.free(unsafe.Pointer(cOrganization.display_name)) + C.free(unsafe.Pointer(cOrganization.org_id)) + C.free(unsafe.Pointer(cOrganization.secure_internet_home)) + C.free(unsafe.Pointer(cOrganization.keyword_list)) + C.free(unsafe.Pointer(cOrganization)) +} + +//export FreeDiscoOrganizations +func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) { + if cOrganizations.total_organizations > 0 { + organizations := (*[1<<30 - 1]*C.discoveryOrganization)(unsafe.Pointer(cOrganizations.organizations))[:cOrganizations.total_organizations:cOrganizations.total_organizations] + for i := C.size_t(0); i < cOrganizations.total_organizations; i++ { + freeDiscoOrganization(organizations[i]) + } + C.free(unsafe.Pointer(cOrganizations.organizations)) + } + C.free(unsafe.Pointer(cOrganizations)) +} + +//export GetDiscoOrganizations +func GetDiscoOrganizations(name *C.char) *C.discoveryOrganizations { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + // TODO + if stateErr != nil { + panic(stateErr) + } + organizations, organizationsErr := state.GetDiscoOrganizations() + // TODO + if organizationsErr != nil { + panic(organizationsErr) + } + + returnedStruct := (*C.discoveryOrganizations)( + C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganizations{}))), + ) + + returnedStruct.version = C.ulonglong(organizations.Version) + returnedStruct.total_organizations, returnedStruct.organizations = getCPtrDiscoOrganizations( + state, + organizations, + ) + + return returnedStruct +} diff --git a/exports/exports.go b/exports/exports.go index b4eb909..797a192 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -1,15 +1,13 @@ package main /* -#include - -typedef void (*PythonCB)(const char* name, int oldstate, int newstate, const char* data); +#cgo CFLAGS: -I${SRCDIR}/c +#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/c +#cgo LDFLAGS: -L${SRCDIR}/c +#cgo LDFLAGS: -lcommon -__attribute__((weak)) -void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, const char* data) -{ - callback(name, oldstate, newstate, data); -} +#include +#include "c/common.h" */ import "C" @@ -26,7 +24,28 @@ var P_StateCallbacks map[string]C.PythonCB var VPNStates map[string]*eduvpn.VPNState +func GetStateData( + state *eduvpn.VPNState, + stateID eduvpn.FSMStateID, + data interface{}, +) unsafe.Pointer { + switch stateID { + case eduvpn.STATE_NO_SERVER: + return (unsafe.Pointer)(getTransitionDataServers(state, data)) + case eduvpn.STATE_OAUTH_STARTED: + if converted, ok := data.(string); ok { + return (unsafe.Pointer)(C.CString(converted)) + } + case eduvpn.STATE_ASK_LOCATION: + return (unsafe.Pointer)(getTransitionSecureLocations(data)) + default: + return nil + } + return nil +} + func StateCallback( + state *eduvpn.VPNState, name string, old_state eduvpn.FSMStateID, new_state eduvpn.FSMStateID, @@ -39,18 +58,10 @@ func StateCallback( name_c := C.CString(name) oldState_c := C.int(old_state) newState_c := C.int(new_state) - data_json, jsonErr := json.Marshal(data) - var dataJsonString string - if jsonErr != nil { - // TODO: How to handle error further? Log? - dataJsonString = "{}" - } else { - dataJsonString = string(data_json) - } - data_c := C.CString(dataJsonString) + data_c := GetStateData(state, new_state, data) C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c) C.free(unsafe.Pointer(name_c)) - C.free(unsafe.Pointer(data_c)) + // data_c gets freed by the wrapper } func GetVPNState(name string) (*eduvpn.VPNState, error) { @@ -87,7 +98,7 @@ func Register( nameStr, C.GoString(config_directory), func(old eduvpn.FSMStateID, new eduvpn.FSMStateID, data interface{}) { - StateCallback(nameStr, old, new, data) + StateCallback(state, nameStr, old, new, data) }, debug != 0, ) @@ -150,7 +161,7 @@ func getConfigJSON(config string, configType string) *C.char { } //export RemoveSecureInternet -func RemoveSecureInternet(name *C.char) (*C.char) { +func RemoveSecureInternet(name *C.char) *C.char { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { @@ -161,7 +172,7 @@ func RemoveSecureInternet(name *C.char) (*C.char) { } //export RemoveInstituteAccess -func RemoveInstituteAccess(name *C.char, url *C.char) (*C.char) { +func RemoveInstituteAccess(name *C.char, url *C.char) *C.char { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { @@ -172,7 +183,7 @@ func RemoveInstituteAccess(name *C.char, url *C.char) (*C.char) { } //export RemoveCustomServer -func RemoveCustomServer(name *C.char, url *C.char) (*C.char) { +func RemoveCustomServer(name *C.char, url *C.char) *C.char { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { @@ -218,17 +229,6 @@ func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char, return getConfigJSON(config, configType), C.CString(ErrorToString(configErr)) } -//export GetDiscoOrganizations -func GetDiscoOrganizations(name *C.char) (*C.char, *C.char) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) - } - organizations, organizationsErr := state.GetDiscoOrganizations() - return C.CString(organizations), C.CString(ErrorToString(organizationsErr)) -} - //export GetDiscoServers func GetDiscoServers(name *C.char) (*C.char, *C.char) { nameStr := C.GoString(name) diff --git a/exports/servers.go b/exports/servers.go new file mode 100644 index 0000000..f92c08e --- /dev/null +++ b/exports/servers.go @@ -0,0 +1,273 @@ +package main + +/* +// for free +#include +#include "c/servers.h" +*/ +import "C" + +import ( + "unsafe" + + "github.com/jwijenbergh/eduvpn-common" + "github.com/jwijenbergh/eduvpn-common/internal/server" +) + +// Get the pointer to the C struct for the profile +// We allocate the struct, the profile ID and the display name +func getCPtrProfile(profile *server.ServerProfile) *C.serverProfile { + // Allocate the struct using malloc and the size of the struct + cProfile := (*C.serverProfile)(C.malloc(C.size_t(unsafe.Sizeof(C.serverProfile{})))) + cProfile.id = C.CString(profile.ID) + cProfile.display_name = C.CString(profile.DisplayName) + if profile.DefaultGateway { + cProfile.default_gateway = C.int(1) + } else { + cProfile.default_gateway = C.int(0) + } + + return cProfile +} + +// Get the pointer to the C struct for the profiles +// We allocate the struct and the struct inside it for the profiles +func getCPtrProfiles(serverProfiles *server.ServerProfileInfo) *C.serverProfiles { + goProfiles := serverProfiles.Info.ProfileList + // Allocate the profles struct using malloc and the size of a pointer + cProfiles := (*C.serverProfiles)(C.malloc(C.size_t(uintptr(0)))) + totalProfiles := C.size_t(len(goProfiles)) + // Defaults if we have no profiles + cProfiles.current = C.int(0) + cProfiles.profiles = nil + cProfiles.total_profiles = totalProfiles + // If we have profiles (which we should), we allocate the struct with malloc and the size of a pointer + // We then fill the struct by converting it to a go slice and get a C pointer for each profile + if totalProfiles > 0 { + profilesPtr := C.malloc(totalProfiles * C.size_t(unsafe.Sizeof(uintptr(0)))) + profiles := (*[1<<30 - 1]*C.serverProfile)(unsafe.Pointer(profilesPtr))[:totalProfiles:totalProfiles] + index := 0 + for _, profile := range goProfiles { + profiles[index] = getCPtrProfile(&profile) + index += 1 + } + // TODO: DO CURRENT PROFILE + cProfiles.profiles = (**C.serverProfile)(profilesPtr) + } + return cProfiles +} + +// Free the profiles by looping through them if there are any +// Also free the pointer itself +func freeCProfiles(profiles *C.serverProfiles) { + // We should only free the profiles if we have them (which we should) + if profiles.total_profiles > 0 { + // Convert it to a go slice + profilesSlice := (*[1<<30 - 1]*C.serverProfile)(unsafe.Pointer(profiles.profiles))[:profiles.total_profiles:profiles.total_profiles] + // Loop through the pointers and free th allocated strings and the struct itself + for i := C.size_t(0); i < profiles.total_profiles; i++ { + C.free(unsafe.Pointer(profilesSlice[i].id)) + C.free(unsafe.Pointer(profilesSlice[i].display_name)) + C.free(unsafe.Pointer(profilesSlice[i])) + } + // Free the inner profiles struct + C.free(unsafe.Pointer(profiles.profiles)) + } + // Free the profiles struct itself + C.free(unsafe.Pointer(profiles)) +} + +// Get a list of strings with a size as a c structure +// Returns the size in size_t and the list of strings as a double pointer char +func getCPtrListStrings(allStrings []string) (C.size_t, **C.char) { + // Get the total strings in size_t + totalStrings := C.size_t(len(allStrings)) + + // If we have strings + // Allocate memory for the strings array + if totalStrings > 0 { + stringsPtr := C.malloc(totalStrings * C.size_t(unsafe.Sizeof(uintptr(0)))) + // Go slice conversion + cStrings := (*[1<<30 - 1]*C.char)(unsafe.Pointer(stringsPtr))[:totalStrings:totalStrings] + + // Loop through and allocate the string for each contact + for index, string := range allStrings { + cStrings[index] = C.CString(string) + } + return totalStrings, (**C.char)(stringsPtr) + } + + // No strings then the length is zero and the char array is nil + return C.size_t(0), nil +} + +// Function for freeing an array/list of strings +// It takes the strings as a pointer to a string and the total strings in size_t +func freeCListStrings(allStrings **C.char, totalStrings C.size_t) { + // If we have strings we should free them + // By converting to a Go slice, and freeing them ony by one + // At last free the pointer itself + if totalStrings > 0 { + stringsSlice := (*[1<<30 - 1]*C.char)(unsafe.Pointer(allStrings))[:totalStrings:totalStrings] + for i := C.size_t(0); i < totalStrings; i++ { + C.free(unsafe.Pointer(stringsSlice[i])) + } + C.free(unsafe.Pointer(allStrings)) + } +} + +// Function for getting the server, +// It gets the main state as a pointer as we need to convert some string maps to localized strings +// It gets the base information for a server as well +func getServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server { + // Allocation using malloc and the size of the struct + server := (*C.server)(C.malloc(C.size_t(unsafe.Sizeof(C.server{})))) + // String allocation and translate the display name + server.identifier = C.CString(base.URL) + server.display_name = C.CString(state.GetTranslated(base.DisplayName)) + // Call the helper to get the list of support contacts + server.total_support_contact, server.support_contact = getCPtrListStrings( + base.SupportContact, + ) + server.profiles = getCPtrProfiles(&base.Profiles) + // No endtime is given if we get servers when it has been partially initialised + if base.EndTime.IsZero() { + server.expire_time = C.ulonglong(0) + } + // The expire time should be stored as an unsigned long long in unix itme + server.expire_time = C.ulonglong(base.EndTime.Unix()) + return server +} + +// Function for freeing a single server +// Gets the pointer to C struct +func freeServer(info *C.server) { + // Free strings + C.free(unsafe.Pointer(info.identifier)) + C.free(unsafe.Pointer(info.display_name)) + + // Free arrays + freeCListStrings(info.support_contact, info.total_support_contact) + freeCProfiles(info.profiles) + + // Free the struct itself + C.free(unsafe.Pointer(info)) +} + +// Get the C ptr to the servers, returns the length in size_t and the double pointer to the struct +func getCPtrServers( + state *eduvpn.VPNState, + serverMap map[string]*server.InstituteAccessServer, +) (C.size_t, **C.server) { + totalServers := C.size_t(len(serverMap)) + // If we have servers, which is not always the case + if totalServers > 0 { + serversPtr := (**C.server)(C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0))))) + servers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(serversPtr))[:totalServers:totalServers] + index := 0 + for _, server := range serverMap { + cServer := getServer(state, &server.Base) + servers[index] = cServer + index += 1 + } + } + return C.size_t(0), nil +} + +//export FreeServers +// This function takes the servers as a C struct pointer as input +// It frees all allocated memory for the server +func FreeServers(cServers *C.servers) { + // Free the custom servers if there are any + if cServers.total_custom > 0 { + customServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.custom_servers))[:cServers.total_custom:cServers.total_custom] + for i := C.size_t(0); i < cServers.total_custom; i++ { + freeServer(customServers[i]) + } + C.free(unsafe.Pointer(cServers.custom_servers)) + } + // Free the institute access servers if there are any + if cServers.total_institute > 0 { + instituteServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.institute_servers))[:cServers.total_institute:cServers.total_institute] + + for i := C.size_t(0); i < cServers.total_institute; i++ { + freeServer(instituteServers[i]) + } + C.free(unsafe.Pointer(cServers.institute_servers)) + } + // Free the secure internet server if there is one + if cServers.secure_internet_server != nil { + C.free(unsafe.Pointer(cServers.secure_internet_server.country_code)) + freeServer(cServers.secure_internet_server) + } + // Free the structure itself + C.free(unsafe.Pointer(cServers)) +} + +// Return the servers as a C struct pointer +// It takes the state as a pointer as we need to translate some strings +// It also takes the servers as a pointer that belongs to the main state or gathered from the callback +func getSavedServersWithOptions(state *eduvpn.VPNState, servers *server.Servers) *C.servers { + // Allocate the struct that we will return + // With the size of the c struct + returnedStruct := (*C.servers)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{})))) + + // Get the different categories of servers + totalCustom, customPtr := getCPtrServers(state, servers.CustomServers.Map) + totalInstitute, institutePtr := getCPtrServers(state, servers.InstituteServers.Map) + var secureServerPtr *C.server = nil + secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase() + if secureInternetBaseErr == nil && secureInternetBase != nil { + // FIXME: log error? + secureServerPtr = getServer(state, secureInternetBase) + // Give a new identifier + C.free(unsafe.Pointer(secureServerPtr.identifier)) + secureServerPtr.identifier = C.CString(servers.SecureInternetHomeServer.HomeOrganizationID) + secureServerPtr.country_code = C.CString(servers.SecureInternetHomeServer.CurrentLocation) + } + + // Fill the struct and return + returnedStruct.custom_servers = customPtr + returnedStruct.total_custom = totalCustom + returnedStruct.institute_servers = institutePtr + returnedStruct.total_institute = totalInstitute + returnedStruct.secure_internet_server = secureServerPtr + return returnedStruct +} + +//export GetSavedServers +// This function takes the name as input which is the name of the client +// It gets the state by name and then returns the saved servers as a c struct belonging to it +func GetSavedServers(name *C.char) *C.servers { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + // TODO: Remove this panic + panic(stateErr) + } + return getSavedServersWithOptions(state, &state.Servers) +} + +// This function takes the state as input which is the main state +// It also takes the data as an interface and if it has the servers type gets the data as a c struct otherwise nil +func getTransitionDataServers(state *eduvpn.VPNState, data interface{}) *C.servers { + if converted, ok := data.(server.Servers); ok { + return getSavedServersWithOptions(state, &converted) + } + return nil +} + +//export FreeSecureLocations +func FreeSecureLocations(locations *C.serverLocations) { + freeCListStrings(locations.locations, locations.total_locations); + C.free(unsafe.Pointer(locations)) +} + +func getTransitionSecureLocations(data interface{}) (*C.serverLocations) { + if locations, ok := data.([]string); ok { + returnedStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{})))) + returnedStruct.total_locations, returnedStruct.locations = getCPtrListStrings(locations) + return returnedStruct + } + return nil +} diff --git a/fsm.go b/fsm.go index 31ed30b..a09e335 100644 --- a/fsm.go +++ b/fsm.go @@ -3,14 +3,17 @@ package eduvpn import ( "errors" "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/fsm" - "github.com/jwijenbergh/eduvpn-common/internal/types" + + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) -type FSMStateID = fsm.FSMStateID -type FSMStates = fsm.FSMStates -type FSMState = fsm.FSMState -type FSMTransition = fsm.FSMTransition +type ( + FSMStateID = fsm.FSMStateID + FSMStates = fsm.FSMStates + FSMState = fsm.FSMState + FSMTransition = fsm.FSMTransition +) const ( // Deregistered means the app is not registered with the wrapper @@ -91,9 +94,16 @@ func GetStateName(s FSMStateID) string { } } -func newFSM(name string, callback func(FSMStateID, FSMStateID, interface{}), directory string, debug bool) fsm.FSM { +func newFSM( + name string, + callback func(FSMStateID, FSMStateID, interface{}), + directory string, + debug bool, +) fsm.FSM { states := FSMStates{ - STATE_DEREGISTERED: FSMState{Transitions: []FSMTransition{{STATE_NO_SERVER, "Client registers"}}}, + STATE_DEREGISTERED: FSMState{ + Transitions: []FSMTransition{{STATE_NO_SERVER, "Client registers"}}, + }, STATE_NO_SERVER: FSMState{ Transitions: []FSMTransition{ {STATE_NO_SERVER, "Reload list"}, @@ -231,7 +241,11 @@ func (e FSMWrongStateError) CustomError() *types.WrappedErrorMessage { return &types.WrappedErrorMessage{ Message: "Wrong FSM State", Err: errors.New( - fmt.Sprintf("wrong FSM state, got: %s, want: %s", GetStateName(e.Got), GetStateName(e.Want)), + fmt.Sprintf( + "wrong FSM state, got: %s, want: %s", + GetStateName(e.Got), + GetStateName(e.Want), + ), ), } } diff --git a/go.mod b/go.mod index 88bb6f1..30b1d70 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,10 @@ go 1.15 require ( github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b - golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect golang.zx2c4.com/wireguard/wgctrl v0.0.0-20220420130459-88a4932fb60b ) + +require ( + golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect + golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect +) diff --git a/go.sum b/go.sum index d536cee..6ecd05e 100644 --- a/go.sum +++ b/go.sum @@ -43,7 +43,6 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0= diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 3ab13b3..21125cb 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -163,22 +163,22 @@ func (discovery *Discovery) DetermineServersUpdate() bool { } // Get the organization list -func (discovery *Discovery) GetOrganizationsList() (string, error) { +func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganizations, error) { if !discovery.DetermineOrganizationsUpdate() { - return discovery.Organizations.RawString, nil + return &discovery.Organizations, nil } file := "organization_list.json" body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) if bodyErr != nil { // Return previous with an error - return discovery.Organizations.RawString, &types.WrappedErrorMessage{ + return &discovery.Organizations, &types.WrappedErrorMessage{ Message: "failed getting organizations in Discovery", Err: bodyErr, } } discovery.Organizations.RawString = body discovery.Organizations.Timestamp = util.GetCurrentTime() - return discovery.Organizations.RawString, nil + return &discovery.Organizations, nil } // Get the server list @@ -206,7 +206,10 @@ type GetOrgByIDNotFoundError struct { } func (e GetOrgByIDNotFoundError) Error() string { - return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s. Please choose your server again", e.ID) + return fmt.Sprintf( + "No Secure Internet Home found in organizations with ID %s. Please choose your server again", + e.ID, + ) } type GetServerByURLNotFoundError struct { @@ -240,5 +243,8 @@ type GetSecureHomeArgsNotFoundError struct { } func (e GetSecureHomeArgsNotFoundError) Error() string { - return fmt.Sprintf("No Secure Internet Home found with URL: %s. Please choose your server again", e.URL) + return fmt.Sprintf( + "No Secure Internet Home found with URL: %s. Please choose your server again", + e.URL, + ) } diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index 63c9ac2..292e09e 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -1,9 +1,9 @@ package fsm import ( + "fmt" "os" "os/exec" - "fmt" "path" "sort" ) @@ -149,7 +149,13 @@ func (fsm *FSM) generateMermaidGraph() string { } else { graph += "\nstyle " + fsm.GetName(state) + " fill:white\n" } - graph += fsm.GetName(state) + "(" + fsm.GetName(state) + ") " + "-->|" + transition.Description + "| " + fsm.GetName(transition.To) + "\n" + graph += fsm.GetName( + state, + ) + "(" + fsm.GetName( + state, + ) + ") " + "-->|" + transition.Description + "| " + fsm.GetName( + transition.To, + ) + "\n" } } return graph @@ -162,4 +168,3 @@ func (fsm *FSM) GenerateGraph() string { return "" } - diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 89854f7..10abf4a 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -315,7 +315,10 @@ func (oauth *OAuth) EnsureTokens() error { errorMessage := "failed ensuring OAuth tokens" // Access Token or Refresh Tokens empty, we can not ensure the tokens if oauth.Token.Access == "" && oauth.Token.Refresh == "" { - return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}} + return &types.WrappedErrorMessage{ + Message: errorMessage, + Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}, + } } // We have tokens... @@ -330,7 +333,12 @@ func (oauth *OAuth) EnsureTokens() error { // We have obtained new tokens with refresh if refreshErr != nil { // We have failed to ensure the tokens due to refresh not working - return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr)}} + return &types.WrappedErrorMessage{ + Message: errorMessage, + Err: &OAuthTokensInvalidError{ + Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr), + }, + } } return nil diff --git a/internal/server/custom.go b/internal/server/custom.go index c757f76..a93242d 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -3,4 +3,3 @@ package server func (servers *Servers) RemoveCustomServer(url string) { servers.CustomServers.Remove(url) } - diff --git a/internal/types/server.go b/internal/types/server.go index a9c46b2..33a6e9c 100644 --- a/internal/types/server.go +++ b/internal/types/server.go @@ -17,12 +17,10 @@ type DiscoveryOrganizations struct { } type DiscoveryOrganization struct { - DisplayName map[string]string `json:"display_name"` - OrgId string `json:"org_id"` - SecureInternetHome string `json:"secure_internet_home"` - KeywordList struct { - En string `json:"en"` - } `json:"keyword_list"` + DisplayName DiscoMapOrString `json:"display_name"` + OrgId string `json:"org_id"` + SecureInternetHome string `json:"secure_internet_home"` + KeywordList DiscoMapOrString `json:"keyword_list"` } // Structs that define the json format for @@ -34,11 +32,11 @@ type DiscoveryServers struct { RawString string `json:"go_raw_string"` } -type DNMapOrString map[string]string +type DiscoMapOrString map[string]string // The display name can either be a map or a string in the server list // Unmarshal it by first trying a string and then the map -func (DN *DNMapOrString) UnmarshalJSON(data []byte) error { +func (DN *DiscoMapOrString) UnmarshalJSON(data []byte) error { var displayNameString string err := json.Unmarshal(data, &displayNameString) @@ -60,11 +58,11 @@ func (DN *DNMapOrString) UnmarshalJSON(data []byte) error { } type DiscoveryServer struct { - AuthenticationURLTemplate string `json:"authentication_url_template"` - BaseURL string `json:"base_url"` - CountryCode string `json:"country_code"` - DisplayName DNMapOrString `json:"display_name,omitempty"` - PublicKeyList []string `json:"public_key_list"` - Type string `json:"server_type"` - SupportContact []string `json:"support_contact"` + AuthenticationURLTemplate string `json:"authentication_url_template"` + BaseURL string `json:"base_url"` + CountryCode string `json:"country_code"` + DisplayName DiscoMapOrString `json:"display_name,omitempty"` + PublicKeyList []string `json:"public_key_list"` + Type string `json:"server_type"` + SupportContact []string `json:"support_contact"` } diff --git a/internal/util/util.go b/internal/util/util.go index b104476..e652779 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -83,3 +83,52 @@ func ReplaceWAYF(authTemplate string, authURL string, orgID string) string { authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1) return authTemplate } + +// https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching +func GetLanguageMatched(languageMap map[string]string, languageTag string) string { + // If no or empty map is given, return the empty string + if languageMap == nil || len(languageMap) == 0 { + return "" + } + // Try to find the exact match + if val, ok := languageMap[languageTag]; ok { + return val + } + // Try to find a key that starts with the OS language setting + for k := range languageMap { + if strings.HasPrefix(k, languageTag) { + return languageMap[k] + } + } + // Try to find a key that starts with the first part of the OS language (e.g. de-) + splitted := strings.Split(languageTag, "-") + // We have a "-" + if len(splitted) > 1 { + for k := range languageMap { + if strings.HasPrefix(k, splitted[0]+"-") { + return languageMap[k] + } + } + } + // search for just the language (e.g. de) + for k := range languageMap { + if k == splitted[0] { + return languageMap[k] + } + } + + // Pick one that is deemed best, e.g. en-US or en, but note that not all languages are always available! + // We force an entry that is english exactly or with an english prefix + for k := range languageMap { + if k == "en" || strings.HasPrefix(k, "en-") { + return languageMap[k] + } + } + + // Otherwise just return one + for k := range languageMap { + return languageMap[k] + } + + return "" +} diff --git a/internal/util/util_test.go b/internal/util/util_test.go index 9eda14e..31d01e7 100644 --- a/internal/util/util_test.go +++ b/internal/util/util_test.go @@ -80,7 +80,11 @@ func Test_WAYFEncode(t *testing.T) { func Test_ReplaceWAYF(t *testing.T) { // We expect url encoding but the spaces to be correctly replace with a + instead of a %20 // And we expect that the return to and org_id are correctly replaced - replaced := ReplaceWAYF("@RETURN_TO@@ORG_ID@", "127.0.0.1:8000/&%$3#kM_- ", "idp-test.nl.org/") + replaced := ReplaceWAYF( + "@RETURN_TO@@ORG_ID@", + "127.0.0.1:8000/&%$3#kM_- ", + "idp-test.nl.org/", + ) wantReplaced := "127.0.0.1%3A8000%2F%26%25%243%23kM_-++++++++++++idp-test.nl.org%2F" if replaced != wantReplaced { t.Fatalf("Got: %s, want: %s", replaced, wantReplaced) @@ -114,3 +118,49 @@ func Test_ReplaceWAYF(t *testing.T) { t.Fatalf("Got: %s, want: %s", replaced, wantReplaced) } } + +func Test_GetLanguageMatched(t *testing.T) { + // func GetLanguageMatched(languageMap map[string]string, languageTag string) string { + + // exact match + returned := GetLanguageMatched(map[string]string{"en": "test", "de": "test2"}, "en") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } + + // starts with language tag + returned = GetLanguageMatched(map[string]string{"en-US-test": "test", "de": "test2"}, "en-US") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } + + // starts with en- + returned = GetLanguageMatched(map[string]string{"en-UK": "test", "en": "test2"}, "en-US") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } + + // exact match for en + returned = GetLanguageMatched(map[string]string{"de": "test", "en": "test2"}, "en-US") + if returned != "test2" { + t.Fatalf("Got: %s, want: %s", returned, "test2") + } + + // We default to english + returned = GetLanguageMatched(map[string]string{"es": "test", "en": "test2"}, "nl-NL") + if returned != "test2" { + t.Fatalf("Got: %s, want: %s", returned, "test2") + } + + // We default to english with a - as well + returned = GetLanguageMatched(map[string]string{"est": "test", "en-": "test2"}, "en-US") + if returned != "test2" { + t.Fatalf("Got: %s, want: %s", returned, "test2") + } + + // None found just return one + returned = GetLanguageMatched(map[string]string{"es": "test"}, "en-US") + if returned != "test" { + t.Fatalf("Got: %s, want: %s", returned, "test") + } +} diff --git a/state.go b/state.go index 1354ffc..4911ac9 100644 --- a/state.go +++ b/state.go @@ -14,9 +14,15 @@ import ( "github.com/jwijenbergh/eduvpn-common/internal/util" ) -type ServerInfo = server.ServerInfoScreen +type ( + ServerInfo = server.ServerInfoScreen + VPNServerBase = server.ServerBase +) type VPNState struct { + // The language used for language matching + Language string `json:"-"` // language should not be saved + // The chosen server Servers server.Servers `json:"servers"` @@ -40,6 +46,7 @@ func (state *VPNState) GetSavedServers() *server.ServersConfiguredScreen { return state.Servers.GetServersConfigured() } +// TODO: Refactor this to a `New` method? func (state *VPNState) Register( name string, directory string, @@ -55,6 +62,7 @@ func (state *VPNState) Register( } // Initialize the logger logLevel := log.LOG_WARNING + state.Language = "en" if debug { logLevel = log.LOG_INFO @@ -83,18 +91,28 @@ func (state *VPNState) Register( _, currentServerErr := state.Servers.GetCurrentServer() // Only actually return the error if we have no disco servers and no current server if discoServersErr != nil && discoServers == "" && currentServerErr != nil { - state.Logger.Error(fmt.Sprintf("No configured servers, discovery servers is empty and no servers with error: %s", GetErrorTraceback(discoServersErr))) + state.Logger.Error( + fmt.Sprintf( + "No configured servers, discovery servers is empty and no servers with error: %s", + GetErrorTraceback(discoServersErr), + ), + ) return &types.WrappedErrorMessage{Message: errorMessage, Err: discoServersErr} } discoOrgs, discoOrgsErr := state.GetDiscoOrganizations() // Only actually return the error if we have no disco organizations and no current server - if discoOrgsErr != nil && discoOrgs == "" && currentServerErr != nil { - state.Logger.Error(fmt.Sprintf("No configured organizations, discovery organizations empty and no servers with error: %s", GetErrorTraceback(discoOrgsErr))) + if discoOrgsErr != nil && discoOrgs.Version == 0 && currentServerErr != nil { + state.Logger.Error( + fmt.Sprintf( + "No configured organizations, discovery organizations empty and no servers with error: %s", + GetErrorTraceback(discoOrgsErr), + ), + ) return &types.WrappedErrorMessage{Message: errorMessage, Err: discoOrgsErr} } // Go to the No Server state with the saved servers - state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), true) + state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, true) return nil } @@ -121,7 +139,7 @@ func (state *VPNState) GoBack() error { } // FIXME: Abitrary back transitions don't work because we need the approriate data - state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false) + state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false) // state.FSM.GoBack() return nil } @@ -157,7 +175,10 @@ func (state *VPNState) ensureLogin(chosenServer server.Server) error { return nil } -func (state *VPNState) getConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) { +func (state *VPNState) getConfigAuth( + chosenServer server.Server, + forceTCP bool, +) (string, string, error) { loginErr := state.ensureLogin(chosenServer) if loginErr != nil { return "", "", loginErr @@ -181,7 +202,10 @@ func (state *VPNState) getConfigAuth(chosenServer server.Server, forceTCP bool) return server.GetConfig(chosenServer, forceTCP) } -func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) { +func (state *VPNState) retryConfigAuth( + chosenServer server.Server, + forceTCP bool, +) (string, string, error) { errorMessage := "failed authorized config retry" config, configType, configErr := state.getConfigAuth(chosenServer, forceTCP) if configErr != nil { @@ -189,10 +213,16 @@ func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool // Only retry if the error is that the tokens are invalid if errors.As(configErr, &error) { - retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth(chosenServer, forceTCP) + retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth( + chosenServer, + forceTCP, + ) if retryConfigErr != nil { state.GoBack() - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: retryConfigErr} + return "", "", &types.WrappedErrorMessage{ + Message: errorMessage, + Err: retryConfigErr, + } } return retryConfig, retryConfigType, nil } @@ -202,7 +232,6 @@ func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool return config, configType, nil } - func (state *VPNState) getConfig( chosenServer server.Server, forceTCP bool, @@ -221,8 +250,13 @@ func (state *VPNState) getConfig( return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } + currentServer, currentServerErr := state.Servers.GetCurrentServer() + if currentServerErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} + } + // Signal the server display info - state.FSM.GoTransitionWithData(STATE_DISCONNECTED, state.getServerInfoData(), false) + state.FSM.GoTransitionWithData(STATE_DISCONNECTED, currentServer, false) // Save the config state.Config.Save(&state) @@ -235,14 +269,25 @@ func (state *VPNState) SetSecureLocation(countryCode string) error { server, serverErr := state.Discovery.GetServerByCountryCode(countryCode, "secure_internet") if serverErr != nil { - state.Logger.Error(fmt.Sprintf("Failed getting secure internet server by country code: %s with error: %s", countryCode, GetErrorTraceback(serverErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed getting secure internet server by country code: %s with error: %s", + countryCode, + GetErrorTraceback(serverErr), + ), + ) state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } setLocationErr := state.Servers.SetSecureLocation(server) if setLocationErr != nil { - state.Logger.Error(fmt.Sprintf("Failed setting secure internet server with error: %s", GetErrorTraceback(serverErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed setting secure internet server with error: %s", + GetErrorTraceback(serverErr), + ), + ) state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: setLocationErr} } @@ -267,7 +312,10 @@ func (state *VPNState) askSecureLocation() error { // The state has changed, meaning setting the secure location was not successful if state.FSM.Current != STATE_ASK_LOCATION { // TODO: maybe a custom type for this errors.new? - return &types.WrappedErrorMessage{Message: "failed setting secure location", Err: errors.New("failed loading secure location")} + return &types.WrappedErrorMessage{ + Message: "failed setting secure location", + Err: errors.New("failed loading secure location"), + } } return nil } @@ -316,7 +364,7 @@ func (state *VPNState) RemoveSecureInternet() error { } // No error because we can only have one secure internet server and if there are no secure internet servers, this is a NO-OP state.Servers.RemoveSecureInternet() - state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false) + state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false) // Save the config state.Config.Save(&state) return nil @@ -331,7 +379,7 @@ func (state *VPNState) RemoveInstituteAccess(url string) error { } // No error because this is a NO-OP if the server doesn't exist state.Servers.RemoveInstituteAccess(url) - state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false) + state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false) // Save the config state.Config.Save(&state) return nil @@ -346,7 +394,7 @@ func (state *VPNState) RemoveCustomServer(url string) error { } // No error because this is a NO-OP if the server doesn't exist state.Servers.RemoveCustomServer(url) - state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false) + state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false) // Save the config state.Config.Save(&state) return nil @@ -363,7 +411,12 @@ func (state *VPNState) GetConfigSecureInternet( state.FSM.GoTransition(STATE_LOADING_SERVER) server, serverErr := state.addSecureInternetHomeServer(orgID) if serverErr != nil { - state.Logger.Error(fmt.Sprintf("Failed adding a secure internet server with error: %s", GetErrorTraceback(serverErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed adding a secure internet server with error: %s", + GetErrorTraceback(serverErr), + ), + ) state.GoBack() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } @@ -372,7 +425,12 @@ func (state *VPNState) GetConfigSecureInternet( config, configType, configErr := state.getConfig(server, forceTCP) if configErr != nil { - state.Logger.Error(fmt.Sprintf("Failed getting a secure internet configuration with error: %s", GetErrorTraceback(configErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed getting a secure internet configuration with error: %s", + GetErrorTraceback(configErr), + ), + ) return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return config, configType, nil @@ -425,14 +483,24 @@ func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (stri state.FSM.GoTransition(STATE_LOADING_SERVER) server, serverErr := state.addInstituteServer(url) if serverErr != nil { - state.Logger.Error(fmt.Sprintf("Failed adding an institute access server with error: %s", GetErrorTraceback(serverErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed adding an institute access server with error: %s", + GetErrorTraceback(serverErr), + ), + ) state.GoBack() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } config, configType, configErr := state.getConfig(server, forceTCP) if configErr != nil { - state.Logger.Error(fmt.Sprintf("Failed getting an institute access server configuration with error: %s", GetErrorTraceback(configErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed getting an institute access server configuration with error: %s", + GetErrorTraceback(configErr), + ), + ) return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return config, configType, nil @@ -444,14 +512,24 @@ func (state *VPNState) GetConfigCustomServer(url string, forceTCP bool) (string, server, serverErr := state.addCustomServer(url) if serverErr != nil { - state.Logger.Error(fmt.Sprintf("Failed adding a custom server with error: %s", GetErrorTraceback(serverErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed adding a custom server with error: %s", + GetErrorTraceback(serverErr), + ), + ) state.GoBack() return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } config, configType, configErr := state.getConfig(server, forceTCP) if configErr != nil { - state.Logger.Error(fmt.Sprintf("Failed getting a custom server with error: %s", GetErrorTraceback(configErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed getting a custom server with error: %s", + GetErrorTraceback(configErr), + ), + ) return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return config, configType, nil @@ -472,7 +550,12 @@ func (state *VPNState) CancelOAuth() error { currentServer, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed cancelling OAuth, no server configured to cancel OAuth for (err: %v)", serverErr)) + state.Logger.Warning( + fmt.Sprintf( + "Failed cancelling OAuth, no server configured to cancel OAuth for (err: %v)", + serverErr, + ), + ) return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } server.CancelOAuth(currentServer) @@ -486,27 +569,43 @@ func (state *VPNState) ChangeSecureLocation() error { state.Logger.Error("Failed changing secure internet location, not in the right state") return &types.WrappedErrorMessage{ Message: errorMessage, - Err: FSMWrongStateError{Got: state.FSM.Current, Want: STATE_NO_SERVER}.CustomError(), + Err: FSMWrongStateError{ + Got: state.FSM.Current, + Want: STATE_NO_SERVER, + }.CustomError(), } } askLocationErr := state.askSecureLocation() if askLocationErr != nil { - state.Logger.Error(fmt.Sprintf("Failed changing secure internet location, err: %s", GetErrorTraceback(askLocationErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed changing secure internet location, err: %s", + GetErrorTraceback(askLocationErr), + ), + ) return &types.WrappedErrorMessage{Message: errorMessage, Err: askLocationErr} } // Go back to the main screen - state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false) + state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false) return nil } -func (state *VPNState) GetDiscoOrganizations() (string, error) { +func (state *VPNState) GetDiscoOrganizations() (*types.DiscoveryOrganizations, error) { orgs, orgsErr := state.Discovery.GetOrganizationsList() if orgsErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed getting discovery organizations, Err: %s", GetErrorTraceback(orgsErr))) - return "", &types.WrappedErrorMessage{Message: "failed getting discovery organizations list", Err: orgsErr} + state.Logger.Warning( + fmt.Sprintf( + "Failed getting discovery organizations, Err: %s", + GetErrorTraceback(orgsErr), + ), + ) + return nil, &types.WrappedErrorMessage{ + Message: "failed getting discovery organizations list", + Err: orgsErr, + } } return orgs, nil } @@ -514,8 +613,13 @@ func (state *VPNState) GetDiscoOrganizations() (string, error) { func (state *VPNState) GetDiscoServers() (string, error) { servers, serversErr := state.Discovery.GetServersList() if serversErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr))) - return "", &types.WrappedErrorMessage{Message: "failed getting discovery servers list", Err: serversErr} + state.Logger.Warning( + fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)), + ) + return "", &types.WrappedErrorMessage{ + Message: "failed getting discovery servers list", + Err: serversErr, + } } return servers, nil } @@ -524,14 +628,21 @@ func (state *VPNState) SetProfileID(profileID string) error { errorMessage := "failed to set the profile ID for the current server" server, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed setting a profile ID because no server configured, Err: %s", GetErrorTraceback(serverErr))) + state.Logger.Warning( + fmt.Sprintf( + "Failed setting a profile ID because no server configured, Err: %s", + GetErrorTraceback(serverErr), + ), + ) state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } base, baseErr := server.GetBase() if baseErr != nil { - state.Logger.Error(fmt.Sprintf("Failed setting a profile ID, Err: %s", GetErrorTraceback(serverErr))) + state.Logger.Error( + fmt.Sprintf("Failed setting a profile ID, Err: %s", GetErrorTraceback(serverErr)), + ) state.GoBack() return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } @@ -541,7 +652,12 @@ func (state *VPNState) SetProfileID(profileID string) error { func (state *VPNState) SetSearchServer() error { if !state.FSM.HasTransition(STATE_SEARCH_SERVER) { - state.Logger.Warning(fmt.Sprintf("Failed setting search server, wrong state %s", GetStateName(state.FSM.Current))) + state.Logger.Warning( + fmt.Sprintf( + "Failed setting search server, wrong state %s", + GetStateName(state.FSM.Current), + ), + ) return &types.WrappedErrorMessage{ Message: "failed to set search server", Err: FSMWrongStateTransitionError{ @@ -558,21 +674,32 @@ func (state *VPNState) SetSearchServer() error { func (state *VPNState) getServerInfoData() *server.ServerInfoScreen { info, infoErr := state.Servers.GetCurrentServerInfo() if infoErr != nil { - state.Logger.Error(fmt.Sprintf("Failed getting server info data with error: %s", GetErrorTraceback(infoErr))) + state.Logger.Error( + fmt.Sprintf( + "Failed getting server info data with error: %s", + GetErrorTraceback(infoErr), + ), + ) } return info } func (state *VPNState) SetConnected() error { + errorMessage := "failed to set connected" if state.InFSMState(STATE_CONNECTED) { // already connected, show no error state.Logger.Warning("Already connected") return nil } if !state.FSM.HasTransition(STATE_CONNECTED) { - state.Logger.Warning(fmt.Sprintf("Failed setting connected, wrong state: %s", GetStateName(state.FSM.Current))) + state.Logger.Warning( + fmt.Sprintf( + "Failed setting connected, wrong state: %s", + GetStateName(state.FSM.Current), + ), + ) return &types.WrappedErrorMessage{ - Message: "failed to set connected", + Message: errorMessage, Err: FSMWrongStateTransitionError{ Got: state.FSM.Current, Want: STATE_CONNECTED, @@ -580,7 +707,18 @@ func (state *VPNState) SetConnected() error { } } - state.FSM.GoTransitionWithData(STATE_CONNECTED, state.getServerInfoData(), false) + currentServer, currentServerErr := state.Servers.GetCurrentServer() + if currentServerErr != nil { + state.Logger.Warning( + fmt.Sprintf( + "Failed setting connected, cannot get current server with error: %s", + GetErrorTraceback(currentServerErr), + ), + ) + return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} + } + + state.FSM.GoTransitionWithData(STATE_CONNECTED, currentServer, false) return nil } @@ -591,7 +729,12 @@ func (state *VPNState) SetConnecting() error { return nil } if !state.FSM.HasTransition(STATE_CONNECTING) { - state.Logger.Warning(fmt.Sprintf("Failed setting connecting, wrong state: %s", GetStateName(state.FSM.Current))) + state.Logger.Warning( + fmt.Sprintf( + "Failed setting connecting, wrong state: %s", + GetStateName(state.FSM.Current), + ), + ) return &types.WrappedErrorMessage{ Message: "failed to set connecting", Err: FSMWrongStateTransitionError{ @@ -606,15 +749,21 @@ func (state *VPNState) SetConnecting() error { } func (state *VPNState) SetDisconnecting() error { + errorMessage := "failed to set disconnecting" if state.InFSMState(STATE_DISCONNECTING) { // already disconnecting, show no error state.Logger.Warning("Already disconnecting") return nil } if !state.FSM.HasTransition(STATE_DISCONNECTING) { - state.Logger.Warning(fmt.Sprintf("Failed setting disconnecting, wrong state: %s", GetStateName(state.FSM.Current))) + state.Logger.Warning( + fmt.Sprintf( + "Failed setting disconnecting, wrong state: %s", + GetStateName(state.FSM.Current), + ), + ) return &types.WrappedErrorMessage{ - Message: "failed to set disconnecting", + Message: errorMessage, Err: FSMWrongStateTransitionError{ Got: state.FSM.Current, Want: STATE_DISCONNECTING, @@ -622,7 +771,18 @@ func (state *VPNState) SetDisconnecting() error { } } - state.FSM.GoTransitionWithData(STATE_DISCONNECTING, state.getServerInfoData(), false) + currentServer, currentServerErr := state.Servers.GetCurrentServer() + if currentServerErr != nil { + state.Logger.Warning( + fmt.Sprintf( + "Failed setting disconnected, cannot get current server with error: %s", + GetErrorTraceback(currentServerErr), + ), + ) + return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} + } + + state.FSM.GoTransitionWithData(STATE_DISCONNECTING, currentServer, false) return nil } @@ -634,7 +794,12 @@ func (state *VPNState) SetDisconnected(cleanup bool) error { return nil } if !state.FSM.HasTransition(STATE_DISCONNECTED) { - state.Logger.Warning(fmt.Sprintf("Failed setting disconnected, wrong state: %s", GetStateName(state.FSM.Current))) + state.Logger.Warning( + fmt.Sprintf( + "Failed setting disconnected, wrong state: %s", + GetStateName(state.FSM.Current), + ), + ) return &types.WrappedErrorMessage{ Message: errorMessage, Err: FSMWrongStateTransitionError{ @@ -644,18 +809,23 @@ func (state *VPNState) SetDisconnected(cleanup bool) error { } } + currentServer, currentServerErr := state.Servers.GetCurrentServer() + if currentServerErr != nil { + state.Logger.Warning( + fmt.Sprintf( + "Failed setting disconnect, failed getting current server with error: %s", + GetErrorTraceback(currentServerErr), + ), + ) + return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} + } + if cleanup { // Do the /disconnect API call and go to disconnected after... - currentServer, currentServerErr := state.Servers.GetCurrentServer() - if currentServerErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed getting current server to send /disconnect API call, error: %s", GetErrorTraceback(currentServerErr))) - return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} - } - server.Disconnect(currentServer) } - state.FSM.GoTransitionWithData(STATE_DISCONNECTED, state.getServerInfoData(), false) + state.FSM.GoTransitionWithData(STATE_DISCONNECTED, currentServer, false) return nil } @@ -665,13 +835,23 @@ func (state *VPNState) RenewSession() error { currentServer, currentServerErr := state.Servers.GetCurrentServer() if currentServerErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed getting current server to renew, error: %s", GetErrorTraceback(currentServerErr))) + state.Logger.Warning( + fmt.Sprintf( + "Failed getting current server to renew, error: %s", + GetErrorTraceback(currentServerErr), + ), + ) return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} } loginErr := state.ensureLogin(currentServer) if loginErr != nil { - state.Logger.Warning(fmt.Sprintf("Failed logging in server for renew, error: %s", GetErrorTraceback(loginErr))) + state.Logger.Warning( + fmt.Sprintf( + "Failed logging in server for renew, error: %s", + GetErrorTraceback(loginErr), + ), + ) return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} } @@ -679,7 +859,9 @@ func (state *VPNState) RenewSession() error { } func (state *VPNState) ShouldRenewButton() bool { - if !state.InFSMState(STATE_CONNECTED) && !state.InFSMState(STATE_CONNECTING) && !state.InFSMState(STATE_DISCONNECTED) && !state.InFSMState(STATE_DISCONNECTING) { + if !state.InFSMState(STATE_CONNECTED) && !state.InFSMState(STATE_CONNECTING) && + !state.InFSMState(STATE_DISCONNECTED) && + !state.InFSMState(STATE_DISCONNECTING) { return false } @@ -717,3 +899,7 @@ func GetErrorTraceback(err error) string { func GetErrorJSONString(err error) (string, error) { return types.GetErrorJSONString(err) } + +func (state *VPNState) GetTranslated(languages map[string]string) string { + return util.GetLanguageMatched(languages, state.Language) +} diff --git a/wrappers/python/main.py b/wrappers/python/main.py index 1ab29cc..0bd2502 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -3,6 +3,8 @@ from eduvpn_common.state import State, StateType import webbrowser import json import sys +import time +from typing import List # Asks the user for a profile index # It loops up until a valid input is given @@ -27,6 +29,11 @@ def ask_profile_input(total: int) -> int: # Sets up the callbacks using the provided class def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None: # The callback that starst OAuth + @_eduvpn.event.on(State.NO_SERVER, StateType.Enter) + def no_server(old_state: str, servers) -> None: + for server in servers: + print(type(server)) + print(server) # It needs to open the URL in the web browser @_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter) def oauth_initialized(old_state: str, url: str) -> None: @@ -34,31 +41,30 @@ def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None: webbrowser.open(url) @_eduvpn.event.on(State.ASK_LOCATION, StateType.Enter) - def ask_location(old_state: str, locations: str): - print("Locations: ", locations) - _eduvpn.set_secure_location("NL") + def ask_location(old_state: str, locations: List[str]): + _eduvpn.set_secure_location(locations[1]) - # The callback which asks the user for a profile - @_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter) - def ask_profile(old_state: str, profiles: str): - print("Multiple profiles found, you need to select a profile:") + ## The callback which asks the user for a profile + #@_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter) + #def ask_profile(old_state: str, profiles: str): + # print("Multiple profiles found, you need to select a profile:") - # Parse the profiles as JSON - data = json.loads(profiles) + # # Parse the profiles as JSON + # data = json.loads(profiles) - # Get a lits of profiles - profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]] - total_profiles = len(profile_strings) + # # Get a lits of profiles + # profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]] + # total_profiles = len(profile_strings) - # Create a list of the strings to standard output - for idx, profile in enumerate(profile_strings): - print(f"{idx+1}. {profile}") + # # Create a list of the strings to standard output + # for idx, profile in enumerate(profile_strings): + # print(f"{idx+1}. {profile}") - # Get the profile index from the user - profile_index = ask_profile_input(total_profiles) + # # Get the profile index from the user + # profile_index = ask_profile_input(total_profiles) - # Set the profile with the index - _eduvpn.set_profile(profile_strings[profile_index]) + # # Set the profile with the index + # _eduvpn.set_profile(profile_strings[profile_index]) # The main entry point @@ -72,18 +78,13 @@ if __name__ == "__main__": except Exception as e: print("Failed registering:", e) - server = input( - "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): " - ) - - # Ensure we have a valid http prefix - if not server.startswith("http"): - # https by default - server = "https://" + server + #server = input( + # "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): " + #) # Get a Wireguard/OpenVPN config try: - config, config_type = _eduvpn.get_config_custom_server(server) + config, config_type = _eduvpn.get_config_secure_internet("https://idp.geant.org") print(f"Got a config with type: {config_type} and contents:\n{config}") except Exception as e: print("Failed to connect:", e) diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index 5b63651..db5484f 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -40,12 +40,66 @@ class ErrorLevel(Enum): ERR_OTHER = 0 ERR_INFO = 1 +class cServerLocations(Structure): + _fields_ = [ + ("locations", POINTER(c_char_p)), + ("total_locations", c_size_t) + ] + +class cDiscoveryOrganization(Structure): + _fields_ = [ + ("display_name", c_char_p), + ("org_id", c_char_p), + ("secure_internet_home", c_char_p), + ("keyword_list", c_char_p), + ] + +class cDiscoveryOrganizations(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("organizations", POINTER(POINTER(cDiscoveryOrganization))), + ("total_organizations", c_size_t), + ] + +class cServerProfile(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("default_gateway", c_int), + ] + +class cServerProfiles(Structure): + _fields_ = [ + ("current", c_int), + ("profiles", POINTER(POINTER(cServerProfile))), + ("total_profiles", c_size_t), + ] + +class cServer(Structure): + _fields_ = [ + ("identifier", c_char_p), + ("display_name", c_char_p), + ("country_code", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ("profiles", POINTER(cServerProfiles)), + ("expire_time", c_ulonglong), + ] + +class cServers(Structure): + _fields_ = [ + ("custom_servers", POINTER(POINTER(cServer))), + ("total_custom", c_size_t), + ("institute_servers", POINTER(POINTER(cServer))), + ("total_institute", c_size_t), + ("secure_internet", POINTER(cServer)), + ] class DataError(Structure): _fields_ = [("data", c_void_p), ("error", c_void_p)] -VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_char_p) +VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly @@ -77,7 +131,7 @@ lib.Register.argtypes, lib.Register.restype = [ ], c_void_p lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ c_char_p -], DataError +], c_void_p lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p @@ -96,8 +150,12 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p +lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None +lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None +lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int +lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p class WrappedError: @@ -139,6 +197,8 @@ def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]: if not error_json: return None + if "level" not in error_json: + return error_string level = error_json["level"] traceback = error_json["traceback"] cause = error_json["cause"] @@ -149,6 +209,9 @@ def get_error(ptr: c_void_p) -> str: error = get_ptr_error(ptr) if not error: return "" + + if not isinstance(error, WrappedError): + return error return error.cause @@ -161,7 +224,6 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]: def get_bool(boolInt: c_int) -> bool: return boolInt == 1 - decode_map = { c_int: get_bool, c_void_p: get_error, diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py new file mode 100644 index 0000000..80c08cf --- /dev/null +++ b/wrappers/python/src/discovery.py @@ -0,0 +1,43 @@ +from . import lib, cDiscoveryOrganizations +from ctypes import cast, POINTER + + +class DiscoOrganization: + def __init__(self, display_name, org_id, secure_internet_home, keyword_list): + self.display_name = display_name + self.org_id = org_id + self.secure_internet_home = secure_internet_home + self.keyword_list = keyword_list + + +class DiscoOrganizations: + def __init__(self, version, organizations): + self.version = version + self.organizations = organizations + + +def get_disco_organization(ptr): + if not ptr: + return None + + current_organization = ptr.contents + display_name = current_organization.display_name.decode("utf-8") + org_id = current_organization.org_id.decode("utf-8") + secure_internet_home = current_organization.secure_internet_home.decode("utf-8") + keyword_list = current_organization.keyword_list.decode("utf-8") + return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) + + +def get_disco_organizations(ptr): + if ptr: + orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents + organizations = [] + if orgs.organizations: + for i in range(orgs.total_organizations): + current = get_disco_organization(orgs.organizations[i]) + if current is None: + continue + organizations.append(current) + lib.FreeDiscoOrganizations(ptr) + return DiscoOrganizations(orgs.version, organizations) + return None diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py index d0740f8..0e0f5ae 100644 --- a/wrappers/python/src/event.py +++ b/wrappers/python/src/event.py @@ -1,7 +1,8 @@ -from . import VPNStateChange +from . import VPNStateChange, get_ptr_string from enum import Enum from typing import Callable -from .state import StateType +from .state import State, StateType +from .server import get_locations, get_servers EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" @@ -15,6 +16,15 @@ def class_state_transition(state: int, state_type: StateType) -> Callable: return wrapper +def convert_data(state: State, data): + if not data: + return None + if state is State.NO_SERVER: + return get_servers(data) + if state is State.OAUTH_STARTED: + return get_ptr_string(data) + if state is State.ASK_LOCATION: + return get_locations(data) class EventHandler(object): def __init__(self): @@ -73,6 +83,7 @@ class EventHandler(object): def run(self, old_state: int, new_state: int, data: str) -> None: # First run leave transitions, then enter # The state is done when the wait event finishes - self.run_state(old_state, new_state, StateType.Leave, data) - self.run_state(new_state, old_state, StateType.Enter, data) - self.run_state(new_state, old_state, StateType.Wait, data) + converted = convert_data(new_state, data) + self.run_state(old_state, new_state, StateType.Leave, converted) + self.run_state(new_state, old_state, StateType.Enter, converted) + self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index b37842f..cbeadb5 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,8 +1,10 @@ from . import lib, VPNStateChange, encode_args, decode_res from typing import Optional, Tuple import threading +from .discovery import get_disco_organizations from .event import EventHandler from .state import State, StateType +from .server import get_servers import json eduvpn_objects = {} @@ -26,7 +28,7 @@ def state_callback(name, old_state, new_state, data): name = name.decode() if name not in eduvpn_objects: return - eduvpn_objects[name].callback(State(old_state), State(new_state), data.decode()) + eduvpn_objects[name].callback(State(old_state), State(new_state), data) class EduVPN(object): @@ -58,6 +60,12 @@ class EduVPN(object): res = func(self.name.encode("utf-8"), *(args_gen)) return decode_res(func.restype)(res) + def go_function_custom_decode(self, func, decode_func, *args): + # The functions all have at least one arg type which is the name of the client + args_gen = encode_args(list(args), func.argtypes[1:]) + res = func(self.name.encode("utf-8"), *(args_gen)) + return decode_func(res) + def cancel_oauth(self) -> None: cancel_oauth_err = self.go_function(lib.CancelOAuth) @@ -91,10 +99,9 @@ class EduVPN(object): return servers def get_disco_organizations(self) -> str: - organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations) - - if organizations_err: - raise Exception(organizations_err) + organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations) + #if organizations_err: + # raise Exception(organizations_err) return organizations @@ -196,7 +203,7 @@ class EduVPN(object): def event(self) -> EventHandler: return self.event_handler - def callback(self, old_state: State, new_state: State, data: str) -> None: + def callback(self, old_state: State, new_state: State, data) -> None: self.event.run(old_state, new_state, data) def set_profile(self, profile_id: str) -> None: @@ -242,3 +249,9 @@ class EduVPN(object): def in_fsm_state(self, state_id: State) -> bool: return self.go_function(lib.InFSMState, state_id) + + def get_saved_servers_old(self) -> str: + return self.go_function(lib.GetSavedServersOLD) + + def get_saved_servers_new(self) -> str: + return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers) diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py new file mode 100644 index 0000000..b765ede --- /dev/null +++ b/wrappers/python/src/server.py @@ -0,0 +1,158 @@ +from . import lib, cServers, cServerLocations +from ctypes import cast, POINTER + + +class Profile: + def __init__(self, identifier, display_name, default_gateway: bool): + self.identifier = identifier + self.display_name = display_name + self.default_gateway = default_gateway + + def __str__(self): + return f"Profile: {self.display_name}" + + +class Server: + def __init__(self, url, display_name, profiles, current_profile, expire_time): + self.url = url + self.display_name = display_name + self.profiles = profiles + self.current_profile = None + if current_profile < len(profiles): + self.current_profile = profiles[current_profile] + self.expire_time = expire_time + + def __str__(self): + return f"Server: {self.url}, with current profile: {self.current_profile}" + + +class InstituteServer(Server): + def __init__( + self, url, display_name, support_contact, profiles, current_profile, expire_time + ): + super().__init__(url, display_name, profiles, current_profile, expire_time) + self.support_contact = support_contact + + def __str__(self): + return f"Institute Server: {self.display_name}" + + +class SecureInternetServer(Server): + def __init__( + self, + url, + display_name, + support_contact, + profiles, + current_profile, + expire_time, + country_code, + ): + super().__init__(url, display_name, profiles, current_profile, expire_time) + self.support_contact = support_contact + self.country_code = country_code + + def __str__(self): + return f"Secure Internet Server: {self.display_name} with country {self.country_code}" + + +def get_type_for_str(type_str: str): + if type_str is "secure_internet": + return SecureInternetServer + if type_str is "custom_server": + return Server + return InstituteServer + + +def get_server(ptr, _type=None): + if not ptr: + return None + + current_server = ptr.contents + if _type is None: + _type = get_type_for_str(current_server.server_type.decode("utf-8")) + + identifier = current_server.identifier.decode("utf-8") + display_name = current_server.display_name.decode("utf-8") + + if _type is not Server: + support_contact = [] + for i in range(current_server.total_support_contact): + support_contact.append(current_server.support_contact[i].decode("utf-8")) + profiles = [] + if not current_server.profiles: + return None + + _profiles = current_server.profiles.contents + current_profile = _profiles.current + for i in range(_profiles.total_profiles): + if not _profiles.profiles or not _profiles.profiles[i]: + return None + profile = _profiles.profiles[i].contents + profiles.append( + Profile( + profile.identifier.decode("utf-8"), + profile.display_name.decode("utf-8"), + profile.default_gateway == 1, + ) + ) + + if _type is SecureInternetServer: + return SecureInternetServer( + identifier, + display_name, + support_contact, + profiles, + current_profile, + current_server.expire_time, + current_server.country_code.decode("utf-8"), + ) + if _type is InstituteServer: + return InstituteServer( + identifier, + display_name, + support_contact, + profiles, + current_profile, + current_server.expire_time, + ) + return Server( + identifier, display_name, profiles, current_profile, current_server.expire_time + ) + + +def get_servers(ptr): + if ptr: + returned = [] + servers = cast(ptr, POINTER(cServers)).contents + if servers.custom_servers: + for i in range(servers.total_custom): + current = get_server(servers.custom_servers[i], Server) + if current is None: + continue + returned.append(current) + + if servers.institute_servers: + for i in range(servers.total_institute): + current = get_server(servers.institute_servers[i], InstituteServer) + if current is None: + continue + returned.append(current) + + if servers.secure_internet: + current = get_server(servers.secure_internet, SecureInternetServer) + if current is not None: + returned.append(current) + lib.FreeServers(ptr) + return returned + return None + +def get_locations(ptr): + if ptr: + locations = cast(ptr, POINTER(cServerLocations)).contents + location_list = [] + for i in range(locations.total_locations): + location_list.append(locations.locations[i].decode("utf-8")) + lib.FreeSecureLocations(ptr) + return location_list + return None diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py index 60d3cce..679eda0 100644 --- a/wrappers/python/tests.py +++ b/wrappers/python/tests.py @@ -24,7 +24,7 @@ class ConfigTests(unittest.TestCase): @_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter) def oauth_initialized(old_state, url_json): - login_eduvpn(json.loads(url_json)) + login_eduvpn(url_json) server_uri = os.getenv("SERVER_URI") if not server_uri: -- cgit v1.2.3