summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-14 13:56:49 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-14 13:56:49 +0200
commitda83f54606c9c1d2786d87074ee17ed972d2e1b2 (patch)
tree0be57934f9f467c87576abb0b457fb54b2d25d52
parentfd34e72da8c604517050ada7e883ba982829d985 (diff)
Refactor: Return without json
-rw-r--r--exports/Makefile12
-rw-r--r--exports/c/common.c6
-rw-r--r--exports/c/common.h3
-rw-r--r--exports/c/disco.h15
-rw-r--r--exports/c/servers.h43
-rw-r--r--exports/common.mk2
-rw-r--r--exports/disco.go97
-rw-r--r--exports/exports.go66
-rw-r--r--exports/servers.go273
-rw-r--r--fsm.go32
-rw-r--r--go.mod6
-rw-r--r--go.sum1
-rw-r--r--internal/discovery/discovery.go18
-rw-r--r--internal/fsm/fsm.go11
-rw-r--r--internal/oauth/oauth.go12
-rw-r--r--internal/server/custom.go1
-rw-r--r--internal/types/server.go28
-rw-r--r--internal/util/util.go49
-rw-r--r--internal/util/util_test.go52
-rw-r--r--state.go296
-rw-r--r--wrappers/python/main.py57
-rw-r--r--wrappers/python/src/__init__.py68
-rw-r--r--wrappers/python/src/discovery.py43
-rw-r--r--wrappers/python/src/event.py21
-rw-r--r--wrappers/python/src/main.py25
-rw-r--r--wrappers/python/src/server.py158
-rw-r--r--wrappers/python/tests.py2
27 files changed, 1225 insertions, 172 deletions
diff --git a/exports/Makefile b/exports/Makefile
index b833228..46d17a9 100644
--- a/exports/Makefile
+++ b/exports/Makefile
@@ -2,6 +2,8 @@
include common.mk
+CLIBPATH=./c
+
ifeq ($(LIB_SUFFIX),.so)
# Add SONAME as cgo does not currently do this. Mostly for Android, see https://stackoverflow.com/a/48291044
export override CGO_LDFLAGS += -Wl,-soname,$(LIB_FILE)
@@ -13,12 +15,18 @@ ifdef COPY_LIB_TO
install $< -Dt $(COPY_LIB_TO)
endif
+${CLIBPATH}/libcommon$(LIB_SUFFIX): ${CLIBPATH}/common.c
+ $(CC) -c -Wall -Werror -fpic -o ${CLIBPATH}/common.o ${CLIBPATH}/common.c
+ $(CC) -shared -o $@ ${CLIBPATH}/common.o
+
# Build shared library and remove lib prefix (if any) from header name
# GOOS and GOARCH envvars are set by common.mk
# This extra target prevents unnecessary rebuild
-lib/$(GOOS)/$(GOARCH)/$(LIB_FILE): exports.go ..
- CGO_ENABLED=1 go build -o $@ -buildmode=c-shared $<
+lib/$(GOOS)/$(GOARCH)/$(LIB_FILE): ${CLIBPATH}/libcommon$(LIB_SUFFIX) exports.go servers.go ..
+ CGO_ENABLED=1 go build -o $@ -buildmode=c-shared .
mv lib/$(GOOS)/$(GOARCH)/$(LIB_PREFIX)$(LIB_NAME).h lib/$(GOOS)/$(GOARCH)/$(LIB_NAME).h || true # Normalize header name
clean:
rm -rf ../exports/lib/*
+ rm -rf ${CLIBPATH}/common.o
+ rm -rf ${CLIBPATH}/libcommon.so
diff --git a/exports/c/common.c b/exports/c/common.c
new file mode 100644
index 0000000..425a459
--- /dev/null
+++ b/exports/c/common.c
@@ -0,0 +1,6 @@
+#include "common.h"
+
+void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data)
+{
+ callback(name, oldstate, newstate, data);
+}
diff --git a/exports/c/common.h b/exports/c/common.h
new file mode 100644
index 0000000..068ad4c
--- /dev/null
+++ b/exports/c/common.h
@@ -0,0 +1,3 @@
+typedef void (*PythonCB)(const char* name, int oldstate, int newstate, void* data);
+
+void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data);
diff --git a/exports/c/disco.h b/exports/c/disco.h
new file mode 100644
index 0000000..41d59fa
--- /dev/null
+++ b/exports/c/disco.h
@@ -0,0 +1,15 @@
+// for size_t
+#include <stddef.h>
+
+typedef struct discoveryOrganization {
+ const char* display_name;
+ const char* org_id;
+ const char* secure_internet_home;
+ const char* keyword_list;
+} discoveryOrganization;
+
+typedef struct discoveryOrganizations {
+ unsigned long long int version;
+ discoveryOrganization** organizations;
+ size_t total_organizations;
+} discoveryOrganizations;
diff --git a/exports/c/servers.h b/exports/c/servers.h
new file mode 100644
index 0000000..39e52a2
--- /dev/null
+++ b/exports/c/servers.h
@@ -0,0 +1,43 @@
+// for size_t
+#include <stddef.h>
+
+// The struct for a single server profile
+typedef struct serverProfile {
+ const char* id;
+ const char* display_name;
+ //const char* proto_list;
+ int default_gateway;
+} serverProfile;
+
+// The struct for all server profiles
+typedef struct serverProfiles {
+ int current;
+ serverProfile** profiles;
+ size_t total_profiles;
+} serverProfiles;
+
+// The struct for server locations
+typedef struct serverLocations {
+ const char** locations;
+ size_t total_locations;
+} serverLocations;
+
+// The struct for a single server
+typedef struct server {
+ const char* identifier;
+ const char* display_name;
+ const char* country_code;
+ const char** support_contact;
+ size_t total_support_contact;
+ serverProfiles* profiles;
+ unsigned long long int expire_time;
+} server;
+
+// The struct for all servers
+typedef struct servers {
+ server** custom_servers;
+ size_t total_custom;
+ server** institute_servers;
+ size_t total_institute;
+ server* secure_internet_server;
+} servers;
diff --git a/exports/common.mk b/exports/common.mk
index c1e2a7e..211d460 100644
--- a/exports/common.mk
+++ b/exports/common.mk
@@ -7,6 +7,8 @@ ifndef GOARCH
export GOARCH := $(shell go env GOHOSTARCH)
endif
+CC = gcc
+
ifeq (windows,$(GOOS))
LIB_PREFIX ?=
LIB_SUFFIX ?= .dll
diff --git a/exports/disco.go b/exports/disco.go
new file mode 100644
index 0000000..9ee2af9
--- /dev/null
+++ b/exports/disco.go
@@ -0,0 +1,97 @@
+package main
+
+/*
+// for free
+#include <stdlib.h>
+#include "c/disco.h"
+*/
+import "C"
+
+import (
+ "unsafe"
+
+ "github.com/jwijenbergh/eduvpn-common"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
+)
+
+func getCPtrDiscoOrganization(
+ state *eduvpn.VPNState,
+ organization *types.DiscoveryOrganization,
+) *C.discoveryOrganization {
+ returnedStruct := (*C.discoveryOrganization)(
+ C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganization{}))),
+ )
+ returnedStruct.display_name = C.CString(state.GetTranslated(organization.DisplayName))
+ returnedStruct.org_id = C.CString(organization.OrgId)
+ returnedStruct.secure_internet_home = C.CString(organization.SecureInternetHome)
+ returnedStruct.keyword_list = C.CString(state.GetTranslated(organization.KeywordList))
+ return returnedStruct
+}
+
+func getCPtrDiscoOrganizations(
+ state *eduvpn.VPNState,
+ organizations *types.DiscoveryOrganizations,
+) (C.size_t, **C.discoveryOrganization) {
+ totalOrganizations := C.size_t(len(organizations.List))
+ var organizationsPtr **C.discoveryOrganization
+ if totalOrganizations > 0 {
+ organizationsPtr = (**C.discoveryOrganization)(
+ C.malloc(totalOrganizations * C.size_t(unsafe.Sizeof(uintptr(0)))),
+ )
+ cOrganizations := (*[1<<30 - 1]*C.discoveryOrganization)(unsafe.Pointer(organizationsPtr))[:totalOrganizations:totalOrganizations]
+ index := 0
+ for _, organization := range organizations.List {
+ cOrganization := getCPtrDiscoOrganization(state, &organization)
+ cOrganizations[index] = cOrganization
+ index += 1
+ }
+ }
+ return totalOrganizations, organizationsPtr
+}
+
+func freeDiscoOrganization(cOrganization *C.discoveryOrganization) {
+ C.free(unsafe.Pointer(cOrganization.display_name))
+ C.free(unsafe.Pointer(cOrganization.org_id))
+ C.free(unsafe.Pointer(cOrganization.secure_internet_home))
+ C.free(unsafe.Pointer(cOrganization.keyword_list))
+ C.free(unsafe.Pointer(cOrganization))
+}
+
+//export FreeDiscoOrganizations
+func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) {
+ if cOrganizations.total_organizations > 0 {
+ organizations := (*[1<<30 - 1]*C.discoveryOrganization)(unsafe.Pointer(cOrganizations.organizations))[:cOrganizations.total_organizations:cOrganizations.total_organizations]
+ for i := C.size_t(0); i < cOrganizations.total_organizations; i++ {
+ freeDiscoOrganization(organizations[i])
+ }
+ C.free(unsafe.Pointer(cOrganizations.organizations))
+ }
+ C.free(unsafe.Pointer(cOrganizations))
+}
+
+//export GetDiscoOrganizations
+func GetDiscoOrganizations(name *C.char) *C.discoveryOrganizations {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ // TODO
+ if stateErr != nil {
+ panic(stateErr)
+ }
+ organizations, organizationsErr := state.GetDiscoOrganizations()
+ // TODO
+ if organizationsErr != nil {
+ panic(organizationsErr)
+ }
+
+ returnedStruct := (*C.discoveryOrganizations)(
+ C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganizations{}))),
+ )
+
+ returnedStruct.version = C.ulonglong(organizations.Version)
+ returnedStruct.total_organizations, returnedStruct.organizations = getCPtrDiscoOrganizations(
+ state,
+ organizations,
+ )
+
+ return returnedStruct
+}
diff --git a/exports/exports.go b/exports/exports.go
index b4eb909..797a192 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -1,15 +1,13 @@
package main
/*
-#include <stdlib.h>
-
-typedef void (*PythonCB)(const char* name, int oldstate, int newstate, const char* data);
+#cgo CFLAGS: -I${SRCDIR}/c
+#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/c
+#cgo LDFLAGS: -L${SRCDIR}/c
+#cgo LDFLAGS: -lcommon
-__attribute__((weak))
-void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, const char* data)
-{
- callback(name, oldstate, newstate, data);
-}
+#include <stdlib.h>
+#include "c/common.h"
*/
import "C"
@@ -26,7 +24,28 @@ var P_StateCallbacks map[string]C.PythonCB
var VPNStates map[string]*eduvpn.VPNState
+func GetStateData(
+ state *eduvpn.VPNState,
+ stateID eduvpn.FSMStateID,
+ data interface{},
+) unsafe.Pointer {
+ switch stateID {
+ case eduvpn.STATE_NO_SERVER:
+ return (unsafe.Pointer)(getTransitionDataServers(state, data))
+ case eduvpn.STATE_OAUTH_STARTED:
+ if converted, ok := data.(string); ok {
+ return (unsafe.Pointer)(C.CString(converted))
+ }
+ case eduvpn.STATE_ASK_LOCATION:
+ return (unsafe.Pointer)(getTransitionSecureLocations(data))
+ default:
+ return nil
+ }
+ return nil
+}
+
func StateCallback(
+ state *eduvpn.VPNState,
name string,
old_state eduvpn.FSMStateID,
new_state eduvpn.FSMStateID,
@@ -39,18 +58,10 @@ func StateCallback(
name_c := C.CString(name)
oldState_c := C.int(old_state)
newState_c := C.int(new_state)
- data_json, jsonErr := json.Marshal(data)
- var dataJsonString string
- if jsonErr != nil {
- // TODO: How to handle error further? Log?
- dataJsonString = "{}"
- } else {
- dataJsonString = string(data_json)
- }
- data_c := C.CString(dataJsonString)
+ data_c := GetStateData(state, new_state, data)
C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c)
C.free(unsafe.Pointer(name_c))
- C.free(unsafe.Pointer(data_c))
+ // data_c gets freed by the wrapper
}
func GetVPNState(name string) (*eduvpn.VPNState, error) {
@@ -87,7 +98,7 @@ func Register(
nameStr,
C.GoString(config_directory),
func(old eduvpn.FSMStateID, new eduvpn.FSMStateID, data interface{}) {
- StateCallback(nameStr, old, new, data)
+ StateCallback(state, nameStr, old, new, data)
},
debug != 0,
)
@@ -150,7 +161,7 @@ func getConfigJSON(config string, configType string) *C.char {
}
//export RemoveSecureInternet
-func RemoveSecureInternet(name *C.char) (*C.char) {
+func RemoveSecureInternet(name *C.char) *C.char {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
@@ -161,7 +172,7 @@ func RemoveSecureInternet(name *C.char) (*C.char) {
}
//export RemoveInstituteAccess
-func RemoveInstituteAccess(name *C.char, url *C.char) (*C.char) {
+func RemoveInstituteAccess(name *C.char, url *C.char) *C.char {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
@@ -172,7 +183,7 @@ func RemoveInstituteAccess(name *C.char, url *C.char) (*C.char) {
}
//export RemoveCustomServer
-func RemoveCustomServer(name *C.char, url *C.char) (*C.char) {
+func RemoveCustomServer(name *C.char, url *C.char) *C.char {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
@@ -218,17 +229,6 @@ func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char,
return getConfigJSON(config, configType), C.CString(ErrorToString(configErr))
}
-//export GetDiscoOrganizations
-func GetDiscoOrganizations(name *C.char) (*C.char, *C.char) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, C.CString(ErrorToString(stateErr))
- }
- organizations, organizationsErr := state.GetDiscoOrganizations()
- return C.CString(organizations), C.CString(ErrorToString(organizationsErr))
-}
-
//export GetDiscoServers
func GetDiscoServers(name *C.char) (*C.char, *C.char) {
nameStr := C.GoString(name)
diff --git a/exports/servers.go b/exports/servers.go
new file mode 100644
index 0000000..f92c08e
--- /dev/null
+++ b/exports/servers.go
@@ -0,0 +1,273 @@
+package main
+
+/*
+// for free
+#include <stdlib.h>
+#include "c/servers.h"
+*/
+import "C"
+
+import (
+ "unsafe"
+
+ "github.com/jwijenbergh/eduvpn-common"
+ "github.com/jwijenbergh/eduvpn-common/internal/server"
+)
+
+// Get the pointer to the C struct for the profile
+// We allocate the struct, the profile ID and the display name
+func getCPtrProfile(profile *server.ServerProfile) *C.serverProfile {
+ // Allocate the struct using malloc and the size of the struct
+ cProfile := (*C.serverProfile)(C.malloc(C.size_t(unsafe.Sizeof(C.serverProfile{}))))
+ cProfile.id = C.CString(profile.ID)
+ cProfile.display_name = C.CString(profile.DisplayName)
+ if profile.DefaultGateway {
+ cProfile.default_gateway = C.int(1)
+ } else {
+ cProfile.default_gateway = C.int(0)
+ }
+
+ return cProfile
+}
+
+// Get the pointer to the C struct for the profiles
+// We allocate the struct and the struct inside it for the profiles
+func getCPtrProfiles(serverProfiles *server.ServerProfileInfo) *C.serverProfiles {
+ goProfiles := serverProfiles.Info.ProfileList
+ // Allocate the profles struct using malloc and the size of a pointer
+ cProfiles := (*C.serverProfiles)(C.malloc(C.size_t(uintptr(0))))
+ totalProfiles := C.size_t(len(goProfiles))
+ // Defaults if we have no profiles
+ cProfiles.current = C.int(0)
+ cProfiles.profiles = nil
+ cProfiles.total_profiles = totalProfiles
+ // If we have profiles (which we should), we allocate the struct with malloc and the size of a pointer
+ // We then fill the struct by converting it to a go slice and get a C pointer for each profile
+ if totalProfiles > 0 {
+ profilesPtr := C.malloc(totalProfiles * C.size_t(unsafe.Sizeof(uintptr(0))))
+ profiles := (*[1<<30 - 1]*C.serverProfile)(unsafe.Pointer(profilesPtr))[:totalProfiles:totalProfiles]
+ index := 0
+ for _, profile := range goProfiles {
+ profiles[index] = getCPtrProfile(&profile)
+ index += 1
+ }
+ // TODO: DO CURRENT PROFILE
+ cProfiles.profiles = (**C.serverProfile)(profilesPtr)
+ }
+ return cProfiles
+}
+
+// Free the profiles by looping through them if there are any
+// Also free the pointer itself
+func freeCProfiles(profiles *C.serverProfiles) {
+ // We should only free the profiles if we have them (which we should)
+ if profiles.total_profiles > 0 {
+ // Convert it to a go slice
+ profilesSlice := (*[1<<30 - 1]*C.serverProfile)(unsafe.Pointer(profiles.profiles))[:profiles.total_profiles:profiles.total_profiles]
+ // Loop through the pointers and free th allocated strings and the struct itself
+ for i := C.size_t(0); i < profiles.total_profiles; i++ {
+ C.free(unsafe.Pointer(profilesSlice[i].id))
+ C.free(unsafe.Pointer(profilesSlice[i].display_name))
+ C.free(unsafe.Pointer(profilesSlice[i]))
+ }
+ // Free the inner profiles struct
+ C.free(unsafe.Pointer(profiles.profiles))
+ }
+ // Free the profiles struct itself
+ C.free(unsafe.Pointer(profiles))
+}
+
+// Get a list of strings with a size as a c structure
+// Returns the size in size_t and the list of strings as a double pointer char
+func getCPtrListStrings(allStrings []string) (C.size_t, **C.char) {
+ // Get the total strings in size_t
+ totalStrings := C.size_t(len(allStrings))
+
+ // If we have strings
+ // Allocate memory for the strings array
+ if totalStrings > 0 {
+ stringsPtr := C.malloc(totalStrings * C.size_t(unsafe.Sizeof(uintptr(0))))
+ // Go slice conversion
+ cStrings := (*[1<<30 - 1]*C.char)(unsafe.Pointer(stringsPtr))[:totalStrings:totalStrings]
+
+ // Loop through and allocate the string for each contact
+ for index, string := range allStrings {
+ cStrings[index] = C.CString(string)
+ }
+ return totalStrings, (**C.char)(stringsPtr)
+ }
+
+ // No strings then the length is zero and the char array is nil
+ return C.size_t(0), nil
+}
+
+// Function for freeing an array/list of strings
+// It takes the strings as a pointer to a string and the total strings in size_t
+func freeCListStrings(allStrings **C.char, totalStrings C.size_t) {
+ // If we have strings we should free them
+ // By converting to a Go slice, and freeing them ony by one
+ // At last free the pointer itself
+ if totalStrings > 0 {
+ stringsSlice := (*[1<<30 - 1]*C.char)(unsafe.Pointer(allStrings))[:totalStrings:totalStrings]
+ for i := C.size_t(0); i < totalStrings; i++ {
+ C.free(unsafe.Pointer(stringsSlice[i]))
+ }
+ C.free(unsafe.Pointer(allStrings))
+ }
+}
+
+// Function for getting the server,
+// It gets the main state as a pointer as we need to convert some string maps to localized strings
+// It gets the base information for a server as well
+func getServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server {
+ // Allocation using malloc and the size of the struct
+ server := (*C.server)(C.malloc(C.size_t(unsafe.Sizeof(C.server{}))))
+ // String allocation and translate the display name
+ server.identifier = C.CString(base.URL)
+ server.display_name = C.CString(state.GetTranslated(base.DisplayName))
+ // Call the helper to get the list of support contacts
+ server.total_support_contact, server.support_contact = getCPtrListStrings(
+ base.SupportContact,
+ )
+ server.profiles = getCPtrProfiles(&base.Profiles)
+ // No endtime is given if we get servers when it has been partially initialised
+ if base.EndTime.IsZero() {
+ server.expire_time = C.ulonglong(0)
+ }
+ // The expire time should be stored as an unsigned long long in unix itme
+ server.expire_time = C.ulonglong(base.EndTime.Unix())
+ return server
+}
+
+// Function for freeing a single server
+// Gets the pointer to C struct
+func freeServer(info *C.server) {
+ // Free strings
+ C.free(unsafe.Pointer(info.identifier))
+ C.free(unsafe.Pointer(info.display_name))
+
+ // Free arrays
+ freeCListStrings(info.support_contact, info.total_support_contact)
+ freeCProfiles(info.profiles)
+
+ // Free the struct itself
+ C.free(unsafe.Pointer(info))
+}
+
+// Get the C ptr to the servers, returns the length in size_t and the double pointer to the struct
+func getCPtrServers(
+ state *eduvpn.VPNState,
+ serverMap map[string]*server.InstituteAccessServer,
+) (C.size_t, **C.server) {
+ totalServers := C.size_t(len(serverMap))
+ // If we have servers, which is not always the case
+ if totalServers > 0 {
+ serversPtr := (**C.server)(C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0)))))
+ servers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(serversPtr))[:totalServers:totalServers]
+ index := 0
+ for _, server := range serverMap {
+ cServer := getServer(state, &server.Base)
+ servers[index] = cServer
+ index += 1
+ }
+ }
+ return C.size_t(0), nil
+}
+
+//export FreeServers
+// This function takes the servers as a C struct pointer as input
+// It frees all allocated memory for the server
+func FreeServers(cServers *C.servers) {
+ // Free the custom servers if there are any
+ if cServers.total_custom > 0 {
+ customServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.custom_servers))[:cServers.total_custom:cServers.total_custom]
+ for i := C.size_t(0); i < cServers.total_custom; i++ {
+ freeServer(customServers[i])
+ }
+ C.free(unsafe.Pointer(cServers.custom_servers))
+ }
+ // Free the institute access servers if there are any
+ if cServers.total_institute > 0 {
+ instituteServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.institute_servers))[:cServers.total_institute:cServers.total_institute]
+
+ for i := C.size_t(0); i < cServers.total_institute; i++ {
+ freeServer(instituteServers[i])
+ }
+ C.free(unsafe.Pointer(cServers.institute_servers))
+ }
+ // Free the secure internet server if there is one
+ if cServers.secure_internet_server != nil {
+ C.free(unsafe.Pointer(cServers.secure_internet_server.country_code))
+ freeServer(cServers.secure_internet_server)
+ }
+ // Free the structure itself
+ C.free(unsafe.Pointer(cServers))
+}
+
+// Return the servers as a C struct pointer
+// It takes the state as a pointer as we need to translate some strings
+// It also takes the servers as a pointer that belongs to the main state or gathered from the callback
+func getSavedServersWithOptions(state *eduvpn.VPNState, servers *server.Servers) *C.servers {
+ // Allocate the struct that we will return
+ // With the size of the c struct
+ returnedStruct := (*C.servers)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{}))))
+
+ // Get the different categories of servers
+ totalCustom, customPtr := getCPtrServers(state, servers.CustomServers.Map)
+ totalInstitute, institutePtr := getCPtrServers(state, servers.InstituteServers.Map)
+ var secureServerPtr *C.server = nil
+ secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase()
+ if secureInternetBaseErr == nil && secureInternetBase != nil {
+ // FIXME: log error?
+ secureServerPtr = getServer(state, secureInternetBase)
+ // Give a new identifier
+ C.free(unsafe.Pointer(secureServerPtr.identifier))
+ secureServerPtr.identifier = C.CString(servers.SecureInternetHomeServer.HomeOrganizationID)
+ secureServerPtr.country_code = C.CString(servers.SecureInternetHomeServer.CurrentLocation)
+ }
+
+ // Fill the struct and return
+ returnedStruct.custom_servers = customPtr
+ returnedStruct.total_custom = totalCustom
+ returnedStruct.institute_servers = institutePtr
+ returnedStruct.total_institute = totalInstitute
+ returnedStruct.secure_internet_server = secureServerPtr
+ return returnedStruct
+}
+
+//export GetSavedServers
+// This function takes the name as input which is the name of the client
+// It gets the state by name and then returns the saved servers as a c struct belonging to it
+func GetSavedServers(name *C.char) *C.servers {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ // TODO: Remove this panic
+ panic(stateErr)
+ }
+ return getSavedServersWithOptions(state, &state.Servers)
+}
+
+// This function takes the state as input which is the main state
+// It also takes the data as an interface and if it has the servers type gets the data as a c struct otherwise nil
+func getTransitionDataServers(state *eduvpn.VPNState, data interface{}) *C.servers {
+ if converted, ok := data.(server.Servers); ok {
+ return getSavedServersWithOptions(state, &converted)
+ }
+ return nil
+}
+
+//export FreeSecureLocations
+func FreeSecureLocations(locations *C.serverLocations) {
+ freeCListStrings(locations.locations, locations.total_locations);
+ C.free(unsafe.Pointer(locations))
+}
+
+func getTransitionSecureLocations(data interface{}) (*C.serverLocations) {
+ if locations, ok := data.([]string); ok {
+ returnedStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{}))))
+ returnedStruct.total_locations, returnedStruct.locations = getCPtrListStrings(locations)
+ return returnedStruct
+ }
+ return nil
+}
diff --git a/fsm.go b/fsm.go
index 31ed30b..a09e335 100644
--- a/fsm.go
+++ b/fsm.go
@@ -3,14 +3,17 @@ package eduvpn
import (
"errors"
"fmt"
- "github.com/jwijenbergh/eduvpn-common/internal/fsm"
- "github.com/jwijenbergh/eduvpn-common/internal/types"
+
+ "github.com/jwijenbergh/eduvpn-common/internal/fsm"
+ "github.com/jwijenbergh/eduvpn-common/internal/types"
)
-type FSMStateID = fsm.FSMStateID
-type FSMStates = fsm.FSMStates
-type FSMState = fsm.FSMState
-type FSMTransition = fsm.FSMTransition
+type (
+ FSMStateID = fsm.FSMStateID
+ FSMStates = fsm.FSMStates
+ FSMState = fsm.FSMState
+ FSMTransition = fsm.FSMTransition
+)
const (
// Deregistered means the app is not registered with the wrapper
@@ -91,9 +94,16 @@ func GetStateName(s FSMStateID) string {
}
}
-func newFSM(name string, callback func(FSMStateID, FSMStateID, interface{}), directory string, debug bool) fsm.FSM {
+func newFSM(
+ name string,
+ callback func(FSMStateID, FSMStateID, interface{}),
+ directory string,
+ debug bool,
+) fsm.FSM {
states := FSMStates{
- STATE_DEREGISTERED: FSMState{Transitions: []FSMTransition{{STATE_NO_SERVER, "Client registers"}}},
+ STATE_DEREGISTERED: FSMState{
+ Transitions: []FSMTransition{{STATE_NO_SERVER, "Client registers"}},
+ },
STATE_NO_SERVER: FSMState{
Transitions: []FSMTransition{
{STATE_NO_SERVER, "Reload list"},
@@ -231,7 +241,11 @@ func (e FSMWrongStateError) CustomError() *types.WrappedErrorMessage {
return &types.WrappedErrorMessage{
Message: "Wrong FSM State",
Err: errors.New(
- fmt.Sprintf("wrong FSM state, got: %s, want: %s", GetStateName(e.Got), GetStateName(e.Want)),
+ fmt.Sprintf(
+ "wrong FSM state, got: %s, want: %s",
+ GetStateName(e.Got),
+ GetStateName(e.Want),
+ ),
),
}
}
diff --git a/go.mod b/go.mod
index 88bb6f1..30b1d70 100644
--- a/go.mod
+++ b/go.mod
@@ -4,6 +4,10 @@ go 1.15
require (
github.com/jedisct1/go-minisign v0.0.0-20211028175153-1c139d1cc84b
- golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20220420130459-88a4932fb60b
)
+
+require (
+ golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 // indirect
+ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect
+)
diff --git a/go.sum b/go.sum
index d536cee..6ecd05e 100644
--- a/go.sum
+++ b/go.sum
@@ -43,7 +43,6 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard v0.0.0-20220407013110-ef5c587f782d h1:q4JksJ2n0fmbXC0Aj0eOs6E0AcPqnKglxWXWFqGD6x0=
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 3ab13b3..21125cb 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -163,22 +163,22 @@ func (discovery *Discovery) DetermineServersUpdate() bool {
}
// Get the organization list
-func (discovery *Discovery) GetOrganizationsList() (string, error) {
+func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganizations, error) {
if !discovery.DetermineOrganizationsUpdate() {
- return discovery.Organizations.RawString, nil
+ return &discovery.Organizations, nil
}
file := "organization_list.json"
body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
if bodyErr != nil {
// Return previous with an error
- return discovery.Organizations.RawString, &types.WrappedErrorMessage{
+ return &discovery.Organizations, &types.WrappedErrorMessage{
Message: "failed getting organizations in Discovery",
Err: bodyErr,
}
}
discovery.Organizations.RawString = body
discovery.Organizations.Timestamp = util.GetCurrentTime()
- return discovery.Organizations.RawString, nil
+ return &discovery.Organizations, nil
}
// Get the server list
@@ -206,7 +206,10 @@ type GetOrgByIDNotFoundError struct {
}
func (e GetOrgByIDNotFoundError) Error() string {
- return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s. Please choose your server again", e.ID)
+ return fmt.Sprintf(
+ "No Secure Internet Home found in organizations with ID %s. Please choose your server again",
+ e.ID,
+ )
}
type GetServerByURLNotFoundError struct {
@@ -240,5 +243,8 @@ type GetSecureHomeArgsNotFoundError struct {
}
func (e GetSecureHomeArgsNotFoundError) Error() string {
- return fmt.Sprintf("No Secure Internet Home found with URL: %s. Please choose your server again", e.URL)
+ return fmt.Sprintf(
+ "No Secure Internet Home found with URL: %s. Please choose your server again",
+ e.URL,
+ )
}
diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go
index 63c9ac2..292e09e 100644
--- a/internal/fsm/fsm.go
+++ b/internal/fsm/fsm.go
@@ -1,9 +1,9 @@
package fsm
import (
+ "fmt"
"os"
"os/exec"
- "fmt"
"path"
"sort"
)
@@ -149,7 +149,13 @@ func (fsm *FSM) generateMermaidGraph() string {
} else {
graph += "\nstyle " + fsm.GetName(state) + " fill:white\n"
}
- graph += fsm.GetName(state) + "(" + fsm.GetName(state) + ") " + "-->|" + transition.Description + "| " + fsm.GetName(transition.To) + "\n"
+ graph += fsm.GetName(
+ state,
+ ) + "(" + fsm.GetName(
+ state,
+ ) + ") " + "-->|" + transition.Description + "| " + fsm.GetName(
+ transition.To,
+ ) + "\n"
}
}
return graph
@@ -162,4 +168,3 @@ func (fsm *FSM) GenerateGraph() string {
return ""
}
-
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 89854f7..10abf4a 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -315,7 +315,10 @@ func (oauth *OAuth) EnsureTokens() error {
errorMessage := "failed ensuring OAuth tokens"
// Access Token or Refresh Tokens empty, we can not ensure the tokens
if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}}
+ return &types.WrappedErrorMessage{
+ Message: errorMessage,
+ Err: &OAuthTokensInvalidError{Cause: "tokens are empty"},
+ }
}
// We have tokens...
@@ -330,7 +333,12 @@ func (oauth *OAuth) EnsureTokens() error {
// We have obtained new tokens with refresh
if refreshErr != nil {
// We have failed to ensure the tokens due to refresh not working
- return &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthTokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr)}}
+ return &types.WrappedErrorMessage{
+ Message: errorMessage,
+ Err: &OAuthTokensInvalidError{
+ Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
+ },
+ }
}
return nil
diff --git a/internal/server/custom.go b/internal/server/custom.go
index c757f76..a93242d 100644
--- a/internal/server/custom.go
+++ b/internal/server/custom.go
@@ -3,4 +3,3 @@ package server
func (servers *Servers) RemoveCustomServer(url string) {
servers.CustomServers.Remove(url)
}
-
diff --git a/internal/types/server.go b/internal/types/server.go
index a9c46b2..33a6e9c 100644
--- a/internal/types/server.go
+++ b/internal/types/server.go
@@ -17,12 +17,10 @@ type DiscoveryOrganizations struct {
}
type DiscoveryOrganization struct {
- DisplayName map[string]string `json:"display_name"`
- OrgId string `json:"org_id"`
- SecureInternetHome string `json:"secure_internet_home"`
- KeywordList struct {
- En string `json:"en"`
- } `json:"keyword_list"`
+ DisplayName DiscoMapOrString `json:"display_name"`
+ OrgId string `json:"org_id"`
+ SecureInternetHome string `json:"secure_internet_home"`
+ KeywordList DiscoMapOrString `json:"keyword_list"`
}
// Structs that define the json format for
@@ -34,11 +32,11 @@ type DiscoveryServers struct {
RawString string `json:"go_raw_string"`
}
-type DNMapOrString map[string]string
+type DiscoMapOrString map[string]string
// The display name can either be a map or a string in the server list
// Unmarshal it by first trying a string and then the map
-func (DN *DNMapOrString) UnmarshalJSON(data []byte) error {
+func (DN *DiscoMapOrString) UnmarshalJSON(data []byte) error {
var displayNameString string
err := json.Unmarshal(data, &displayNameString)
@@ -60,11 +58,11 @@ func (DN *DNMapOrString) UnmarshalJSON(data []byte) error {
}
type DiscoveryServer struct {
- AuthenticationURLTemplate string `json:"authentication_url_template"`
- BaseURL string `json:"base_url"`
- CountryCode string `json:"country_code"`
- DisplayName DNMapOrString `json:"display_name,omitempty"`
- PublicKeyList []string `json:"public_key_list"`
- Type string `json:"server_type"`
- SupportContact []string `json:"support_contact"`
+ AuthenticationURLTemplate string `json:"authentication_url_template"`
+ BaseURL string `json:"base_url"`
+ CountryCode string `json:"country_code"`
+ DisplayName DiscoMapOrString `json:"display_name,omitempty"`
+ PublicKeyList []string `json:"public_key_list"`
+ Type string `json:"server_type"`
+ SupportContact []string `json:"support_contact"`
}
diff --git a/internal/util/util.go b/internal/util/util.go
index b104476..e652779 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -83,3 +83,52 @@ func ReplaceWAYF(authTemplate string, authURL string, orgID string) string {
authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1)
return authTemplate
}
+
+// https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY.md#language-matching
+func GetLanguageMatched(languageMap map[string]string, languageTag string) string {
+ // If no or empty map is given, return the empty string
+ if languageMap == nil || len(languageMap) == 0 {
+ return ""
+ }
+ // Try to find the exact match
+ if val, ok := languageMap[languageTag]; ok {
+ return val
+ }
+ // Try to find a key that starts with the OS language setting
+ for k := range languageMap {
+ if strings.HasPrefix(k, languageTag) {
+ return languageMap[k]
+ }
+ }
+ // Try to find a key that starts with the first part of the OS language (e.g. de-)
+ splitted := strings.Split(languageTag, "-")
+ // We have a "-"
+ if len(splitted) > 1 {
+ for k := range languageMap {
+ if strings.HasPrefix(k, splitted[0]+"-") {
+ return languageMap[k]
+ }
+ }
+ }
+ // search for just the language (e.g. de)
+ for k := range languageMap {
+ if k == splitted[0] {
+ return languageMap[k]
+ }
+ }
+
+ // Pick one that is deemed best, e.g. en-US or en, but note that not all languages are always available!
+ // We force an entry that is english exactly or with an english prefix
+ for k := range languageMap {
+ if k == "en" || strings.HasPrefix(k, "en-") {
+ return languageMap[k]
+ }
+ }
+
+ // Otherwise just return one
+ for k := range languageMap {
+ return languageMap[k]
+ }
+
+ return ""
+}
diff --git a/internal/util/util_test.go b/internal/util/util_test.go
index 9eda14e..31d01e7 100644
--- a/internal/util/util_test.go
+++ b/internal/util/util_test.go
@@ -80,7 +80,11 @@ func Test_WAYFEncode(t *testing.T) {
func Test_ReplaceWAYF(t *testing.T) {
// We expect url encoding but the spaces to be correctly replace with a + instead of a %20
// And we expect that the return to and org_id are correctly replaced
- replaced := ReplaceWAYF("@RETURN_TO@@ORG_ID@", "127.0.0.1:8000/&%$3#kM_- ", "idp-test.nl.org/")
+ replaced := ReplaceWAYF(
+ "@RETURN_TO@@ORG_ID@",
+ "127.0.0.1:8000/&%$3#kM_- ",
+ "idp-test.nl.org/",
+ )
wantReplaced := "127.0.0.1%3A8000%2F%26%25%243%23kM_-++++++++++++idp-test.nl.org%2F"
if replaced != wantReplaced {
t.Fatalf("Got: %s, want: %s", replaced, wantReplaced)
@@ -114,3 +118,49 @@ func Test_ReplaceWAYF(t *testing.T) {
t.Fatalf("Got: %s, want: %s", replaced, wantReplaced)
}
}
+
+func Test_GetLanguageMatched(t *testing.T) {
+ // func GetLanguageMatched(languageMap map[string]string, languageTag string) string {
+
+ // exact match
+ returned := GetLanguageMatched(map[string]string{"en": "test", "de": "test2"}, "en")
+ if returned != "test" {
+ t.Fatalf("Got: %s, want: %s", returned, "test")
+ }
+
+ // starts with language tag
+ returned = GetLanguageMatched(map[string]string{"en-US-test": "test", "de": "test2"}, "en-US")
+ if returned != "test" {
+ t.Fatalf("Got: %s, want: %s", returned, "test")
+ }
+
+ // starts with en-
+ returned = GetLanguageMatched(map[string]string{"en-UK": "test", "en": "test2"}, "en-US")
+ if returned != "test" {
+ t.Fatalf("Got: %s, want: %s", returned, "test")
+ }
+
+ // exact match for en
+ returned = GetLanguageMatched(map[string]string{"de": "test", "en": "test2"}, "en-US")
+ if returned != "test2" {
+ t.Fatalf("Got: %s, want: %s", returned, "test2")
+ }
+
+ // We default to english
+ returned = GetLanguageMatched(map[string]string{"es": "test", "en": "test2"}, "nl-NL")
+ if returned != "test2" {
+ t.Fatalf("Got: %s, want: %s", returned, "test2")
+ }
+
+ // We default to english with a - as well
+ returned = GetLanguageMatched(map[string]string{"est": "test", "en-": "test2"}, "en-US")
+ if returned != "test2" {
+ t.Fatalf("Got: %s, want: %s", returned, "test2")
+ }
+
+ // None found just return one
+ returned = GetLanguageMatched(map[string]string{"es": "test"}, "en-US")
+ if returned != "test" {
+ t.Fatalf("Got: %s, want: %s", returned, "test")
+ }
+}
diff --git a/state.go b/state.go
index 1354ffc..4911ac9 100644
--- a/state.go
+++ b/state.go
@@ -14,9 +14,15 @@ import (
"github.com/jwijenbergh/eduvpn-common/internal/util"
)
-type ServerInfo = server.ServerInfoScreen
+type (
+ ServerInfo = server.ServerInfoScreen
+ VPNServerBase = server.ServerBase
+)
type VPNState struct {
+ // The language used for language matching
+ Language string `json:"-"` // language should not be saved
+
// The chosen server
Servers server.Servers `json:"servers"`
@@ -40,6 +46,7 @@ func (state *VPNState) GetSavedServers() *server.ServersConfiguredScreen {
return state.Servers.GetServersConfigured()
}
+// TODO: Refactor this to a `New` method?
func (state *VPNState) Register(
name string,
directory string,
@@ -55,6 +62,7 @@ func (state *VPNState) Register(
}
// Initialize the logger
logLevel := log.LOG_WARNING
+ state.Language = "en"
if debug {
logLevel = log.LOG_INFO
@@ -83,18 +91,28 @@ func (state *VPNState) Register(
_, currentServerErr := state.Servers.GetCurrentServer()
// Only actually return the error if we have no disco servers and no current server
if discoServersErr != nil && discoServers == "" && currentServerErr != nil {
- state.Logger.Error(fmt.Sprintf("No configured servers, discovery servers is empty and no servers with error: %s", GetErrorTraceback(discoServersErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "No configured servers, discovery servers is empty and no servers with error: %s",
+ GetErrorTraceback(discoServersErr),
+ ),
+ )
return &types.WrappedErrorMessage{Message: errorMessage, Err: discoServersErr}
}
discoOrgs, discoOrgsErr := state.GetDiscoOrganizations()
// Only actually return the error if we have no disco organizations and no current server
- if discoOrgsErr != nil && discoOrgs == "" && currentServerErr != nil {
- state.Logger.Error(fmt.Sprintf("No configured organizations, discovery organizations empty and no servers with error: %s", GetErrorTraceback(discoOrgsErr)))
+ if discoOrgsErr != nil && discoOrgs.Version == 0 && currentServerErr != nil {
+ state.Logger.Error(
+ fmt.Sprintf(
+ "No configured organizations, discovery organizations empty and no servers with error: %s",
+ GetErrorTraceback(discoOrgsErr),
+ ),
+ )
return &types.WrappedErrorMessage{Message: errorMessage, Err: discoOrgsErr}
}
// Go to the No Server state with the saved servers
- state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), true)
+ state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, true)
return nil
}
@@ -121,7 +139,7 @@ func (state *VPNState) GoBack() error {
}
// FIXME: Abitrary back transitions don't work because we need the approriate data
- state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false)
+ state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false)
// state.FSM.GoBack()
return nil
}
@@ -157,7 +175,10 @@ func (state *VPNState) ensureLogin(chosenServer server.Server) error {
return nil
}
-func (state *VPNState) getConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) {
+func (state *VPNState) getConfigAuth(
+ chosenServer server.Server,
+ forceTCP bool,
+) (string, string, error) {
loginErr := state.ensureLogin(chosenServer)
if loginErr != nil {
return "", "", loginErr
@@ -181,7 +202,10 @@ func (state *VPNState) getConfigAuth(chosenServer server.Server, forceTCP bool)
return server.GetConfig(chosenServer, forceTCP)
}
-func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool) (string, string, error) {
+func (state *VPNState) retryConfigAuth(
+ chosenServer server.Server,
+ forceTCP bool,
+) (string, string, error) {
errorMessage := "failed authorized config retry"
config, configType, configErr := state.getConfigAuth(chosenServer, forceTCP)
if configErr != nil {
@@ -189,10 +213,16 @@ func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool
// Only retry if the error is that the tokens are invalid
if errors.As(configErr, &error) {
- retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth(chosenServer, forceTCP)
+ retryConfig, retryConfigType, retryConfigErr := state.getConfigAuth(
+ chosenServer,
+ forceTCP,
+ )
if retryConfigErr != nil {
state.GoBack()
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: retryConfigErr}
+ return "", "", &types.WrappedErrorMessage{
+ Message: errorMessage,
+ Err: retryConfigErr,
+ }
}
return retryConfig, retryConfigType, nil
}
@@ -202,7 +232,6 @@ func (state *VPNState) retryConfigAuth(chosenServer server.Server, forceTCP bool
return config, configType, nil
}
-
func (state *VPNState) getConfig(
chosenServer server.Server,
forceTCP bool,
@@ -221,8 +250,13 @@ func (state *VPNState) getConfig(
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
+ currentServer, currentServerErr := state.Servers.GetCurrentServer()
+ if currentServerErr != nil {
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
+ }
+
// Signal the server display info
- state.FSM.GoTransitionWithData(STATE_DISCONNECTED, state.getServerInfoData(), false)
+ state.FSM.GoTransitionWithData(STATE_DISCONNECTED, currentServer, false)
// Save the config
state.Config.Save(&state)
@@ -235,14 +269,25 @@ func (state *VPNState) SetSecureLocation(countryCode string) error {
server, serverErr := state.Discovery.GetServerByCountryCode(countryCode, "secure_internet")
if serverErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed getting secure internet server by country code: %s with error: %s", countryCode, GetErrorTraceback(serverErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed getting secure internet server by country code: %s with error: %s",
+ countryCode,
+ GetErrorTraceback(serverErr),
+ ),
+ )
state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
setLocationErr := state.Servers.SetSecureLocation(server)
if setLocationErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed setting secure internet server with error: %s", GetErrorTraceback(serverErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed setting secure internet server with error: %s",
+ GetErrorTraceback(serverErr),
+ ),
+ )
state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: setLocationErr}
}
@@ -267,7 +312,10 @@ func (state *VPNState) askSecureLocation() error {
// The state has changed, meaning setting the secure location was not successful
if state.FSM.Current != STATE_ASK_LOCATION {
// TODO: maybe a custom type for this errors.new?
- return &types.WrappedErrorMessage{Message: "failed setting secure location", Err: errors.New("failed loading secure location")}
+ return &types.WrappedErrorMessage{
+ Message: "failed setting secure location",
+ Err: errors.New("failed loading secure location"),
+ }
}
return nil
}
@@ -316,7 +364,7 @@ func (state *VPNState) RemoveSecureInternet() error {
}
// No error because we can only have one secure internet server and if there are no secure internet servers, this is a NO-OP
state.Servers.RemoveSecureInternet()
- state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false)
+ state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false)
// Save the config
state.Config.Save(&state)
return nil
@@ -331,7 +379,7 @@ func (state *VPNState) RemoveInstituteAccess(url string) error {
}
// No error because this is a NO-OP if the server doesn't exist
state.Servers.RemoveInstituteAccess(url)
- state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false)
+ state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false)
// Save the config
state.Config.Save(&state)
return nil
@@ -346,7 +394,7 @@ func (state *VPNState) RemoveCustomServer(url string) error {
}
// No error because this is a NO-OP if the server doesn't exist
state.Servers.RemoveCustomServer(url)
- state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false)
+ state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false)
// Save the config
state.Config.Save(&state)
return nil
@@ -363,7 +411,12 @@ func (state *VPNState) GetConfigSecureInternet(
state.FSM.GoTransition(STATE_LOADING_SERVER)
server, serverErr := state.addSecureInternetHomeServer(orgID)
if serverErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed adding a secure internet server with error: %s", GetErrorTraceback(serverErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed adding a secure internet server with error: %s",
+ GetErrorTraceback(serverErr),
+ ),
+ )
state.GoBack()
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
@@ -372,7 +425,12 @@ func (state *VPNState) GetConfigSecureInternet(
config, configType, configErr := state.getConfig(server, forceTCP)
if configErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed getting a secure internet configuration with error: %s", GetErrorTraceback(configErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed getting a secure internet configuration with error: %s",
+ GetErrorTraceback(configErr),
+ ),
+ )
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
return config, configType, nil
@@ -425,14 +483,24 @@ func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (stri
state.FSM.GoTransition(STATE_LOADING_SERVER)
server, serverErr := state.addInstituteServer(url)
if serverErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed adding an institute access server with error: %s", GetErrorTraceback(serverErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed adding an institute access server with error: %s",
+ GetErrorTraceback(serverErr),
+ ),
+ )
state.GoBack()
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
config, configType, configErr := state.getConfig(server, forceTCP)
if configErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed getting an institute access server configuration with error: %s", GetErrorTraceback(configErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed getting an institute access server configuration with error: %s",
+ GetErrorTraceback(configErr),
+ ),
+ )
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
return config, configType, nil
@@ -444,14 +512,24 @@ func (state *VPNState) GetConfigCustomServer(url string, forceTCP bool) (string,
server, serverErr := state.addCustomServer(url)
if serverErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed adding a custom server with error: %s", GetErrorTraceback(serverErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed adding a custom server with error: %s",
+ GetErrorTraceback(serverErr),
+ ),
+ )
state.GoBack()
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
config, configType, configErr := state.getConfig(server, forceTCP)
if configErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed getting a custom server with error: %s", GetErrorTraceback(configErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed getting a custom server with error: %s",
+ GetErrorTraceback(configErr),
+ ),
+ )
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr}
}
return config, configType, nil
@@ -472,7 +550,12 @@ func (state *VPNState) CancelOAuth() error {
currentServer, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed cancelling OAuth, no server configured to cancel OAuth for (err: %v)", serverErr))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed cancelling OAuth, no server configured to cancel OAuth for (err: %v)",
+ serverErr,
+ ),
+ )
return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
server.CancelOAuth(currentServer)
@@ -486,27 +569,43 @@ func (state *VPNState) ChangeSecureLocation() error {
state.Logger.Error("Failed changing secure internet location, not in the right state")
return &types.WrappedErrorMessage{
Message: errorMessage,
- Err: FSMWrongStateError{Got: state.FSM.Current, Want: STATE_NO_SERVER}.CustomError(),
+ Err: FSMWrongStateError{
+ Got: state.FSM.Current,
+ Want: STATE_NO_SERVER,
+ }.CustomError(),
}
}
askLocationErr := state.askSecureLocation()
if askLocationErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed changing secure internet location, err: %s", GetErrorTraceback(askLocationErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed changing secure internet location, err: %s",
+ GetErrorTraceback(askLocationErr),
+ ),
+ )
return &types.WrappedErrorMessage{Message: errorMessage, Err: askLocationErr}
}
// Go back to the main screen
- state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.GetSavedServers(), false)
+ state.FSM.GoTransitionWithData(STATE_NO_SERVER, state.Servers, false)
return nil
}
-func (state *VPNState) GetDiscoOrganizations() (string, error) {
+func (state *VPNState) GetDiscoOrganizations() (*types.DiscoveryOrganizations, error) {
orgs, orgsErr := state.Discovery.GetOrganizationsList()
if orgsErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed getting discovery organizations, Err: %s", GetErrorTraceback(orgsErr)))
- return "", &types.WrappedErrorMessage{Message: "failed getting discovery organizations list", Err: orgsErr}
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed getting discovery organizations, Err: %s",
+ GetErrorTraceback(orgsErr),
+ ),
+ )
+ return nil, &types.WrappedErrorMessage{
+ Message: "failed getting discovery organizations list",
+ Err: orgsErr,
+ }
}
return orgs, nil
}
@@ -514,8 +613,13 @@ func (state *VPNState) GetDiscoOrganizations() (string, error) {
func (state *VPNState) GetDiscoServers() (string, error) {
servers, serversErr := state.Discovery.GetServersList()
if serversErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)))
- return "", &types.WrappedErrorMessage{Message: "failed getting discovery servers list", Err: serversErr}
+ state.Logger.Warning(
+ fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)),
+ )
+ return "", &types.WrappedErrorMessage{
+ Message: "failed getting discovery servers list",
+ Err: serversErr,
+ }
}
return servers, nil
}
@@ -524,14 +628,21 @@ func (state *VPNState) SetProfileID(profileID string) error {
errorMessage := "failed to set the profile ID for the current server"
server, serverErr := state.Servers.GetCurrentServer()
if serverErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed setting a profile ID because no server configured, Err: %s", GetErrorTraceback(serverErr)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting a profile ID because no server configured, Err: %s",
+ GetErrorTraceback(serverErr),
+ ),
+ )
state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
}
base, baseErr := server.GetBase()
if baseErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed setting a profile ID, Err: %s", GetErrorTraceback(serverErr)))
+ state.Logger.Error(
+ fmt.Sprintf("Failed setting a profile ID, Err: %s", GetErrorTraceback(serverErr)),
+ )
state.GoBack()
return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
@@ -541,7 +652,12 @@ func (state *VPNState) SetProfileID(profileID string) error {
func (state *VPNState) SetSearchServer() error {
if !state.FSM.HasTransition(STATE_SEARCH_SERVER) {
- state.Logger.Warning(fmt.Sprintf("Failed setting search server, wrong state %s", GetStateName(state.FSM.Current)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting search server, wrong state %s",
+ GetStateName(state.FSM.Current),
+ ),
+ )
return &types.WrappedErrorMessage{
Message: "failed to set search server",
Err: FSMWrongStateTransitionError{
@@ -558,21 +674,32 @@ func (state *VPNState) SetSearchServer() error {
func (state *VPNState) getServerInfoData() *server.ServerInfoScreen {
info, infoErr := state.Servers.GetCurrentServerInfo()
if infoErr != nil {
- state.Logger.Error(fmt.Sprintf("Failed getting server info data with error: %s", GetErrorTraceback(infoErr)))
+ state.Logger.Error(
+ fmt.Sprintf(
+ "Failed getting server info data with error: %s",
+ GetErrorTraceback(infoErr),
+ ),
+ )
}
return info
}
func (state *VPNState) SetConnected() error {
+ errorMessage := "failed to set connected"
if state.InFSMState(STATE_CONNECTED) {
// already connected, show no error
state.Logger.Warning("Already connected")
return nil
}
if !state.FSM.HasTransition(STATE_CONNECTED) {
- state.Logger.Warning(fmt.Sprintf("Failed setting connected, wrong state: %s", GetStateName(state.FSM.Current)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting connected, wrong state: %s",
+ GetStateName(state.FSM.Current),
+ ),
+ )
return &types.WrappedErrorMessage{
- Message: "failed to set connected",
+ Message: errorMessage,
Err: FSMWrongStateTransitionError{
Got: state.FSM.Current,
Want: STATE_CONNECTED,
@@ -580,7 +707,18 @@ func (state *VPNState) SetConnected() error {
}
}
- state.FSM.GoTransitionWithData(STATE_CONNECTED, state.getServerInfoData(), false)
+ currentServer, currentServerErr := state.Servers.GetCurrentServer()
+ if currentServerErr != nil {
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting connected, cannot get current server with error: %s",
+ GetErrorTraceback(currentServerErr),
+ ),
+ )
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
+ }
+
+ state.FSM.GoTransitionWithData(STATE_CONNECTED, currentServer, false)
return nil
}
@@ -591,7 +729,12 @@ func (state *VPNState) SetConnecting() error {
return nil
}
if !state.FSM.HasTransition(STATE_CONNECTING) {
- state.Logger.Warning(fmt.Sprintf("Failed setting connecting, wrong state: %s", GetStateName(state.FSM.Current)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting connecting, wrong state: %s",
+ GetStateName(state.FSM.Current),
+ ),
+ )
return &types.WrappedErrorMessage{
Message: "failed to set connecting",
Err: FSMWrongStateTransitionError{
@@ -606,15 +749,21 @@ func (state *VPNState) SetConnecting() error {
}
func (state *VPNState) SetDisconnecting() error {
+ errorMessage := "failed to set disconnecting"
if state.InFSMState(STATE_DISCONNECTING) {
// already disconnecting, show no error
state.Logger.Warning("Already disconnecting")
return nil
}
if !state.FSM.HasTransition(STATE_DISCONNECTING) {
- state.Logger.Warning(fmt.Sprintf("Failed setting disconnecting, wrong state: %s", GetStateName(state.FSM.Current)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting disconnecting, wrong state: %s",
+ GetStateName(state.FSM.Current),
+ ),
+ )
return &types.WrappedErrorMessage{
- Message: "failed to set disconnecting",
+ Message: errorMessage,
Err: FSMWrongStateTransitionError{
Got: state.FSM.Current,
Want: STATE_DISCONNECTING,
@@ -622,7 +771,18 @@ func (state *VPNState) SetDisconnecting() error {
}
}
- state.FSM.GoTransitionWithData(STATE_DISCONNECTING, state.getServerInfoData(), false)
+ currentServer, currentServerErr := state.Servers.GetCurrentServer()
+ if currentServerErr != nil {
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting disconnected, cannot get current server with error: %s",
+ GetErrorTraceback(currentServerErr),
+ ),
+ )
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
+ }
+
+ state.FSM.GoTransitionWithData(STATE_DISCONNECTING, currentServer, false)
return nil
}
@@ -634,7 +794,12 @@ func (state *VPNState) SetDisconnected(cleanup bool) error {
return nil
}
if !state.FSM.HasTransition(STATE_DISCONNECTED) {
- state.Logger.Warning(fmt.Sprintf("Failed setting disconnected, wrong state: %s", GetStateName(state.FSM.Current)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting disconnected, wrong state: %s",
+ GetStateName(state.FSM.Current),
+ ),
+ )
return &types.WrappedErrorMessage{
Message: errorMessage,
Err: FSMWrongStateTransitionError{
@@ -644,18 +809,23 @@ func (state *VPNState) SetDisconnected(cleanup bool) error {
}
}
+ currentServer, currentServerErr := state.Servers.GetCurrentServer()
+ if currentServerErr != nil {
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting disconnect, failed getting current server with error: %s",
+ GetErrorTraceback(currentServerErr),
+ ),
+ )
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
+ }
+
if cleanup {
// Do the /disconnect API call and go to disconnected after...
- currentServer, currentServerErr := state.Servers.GetCurrentServer()
- if currentServerErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed getting current server to send /disconnect API call, error: %s", GetErrorTraceback(currentServerErr)))
- return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
- }
-
server.Disconnect(currentServer)
}
- state.FSM.GoTransitionWithData(STATE_DISCONNECTED, state.getServerInfoData(), false)
+ state.FSM.GoTransitionWithData(STATE_DISCONNECTED, currentServer, false)
return nil
}
@@ -665,13 +835,23 @@ func (state *VPNState) RenewSession() error {
currentServer, currentServerErr := state.Servers.GetCurrentServer()
if currentServerErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed getting current server to renew, error: %s", GetErrorTraceback(currentServerErr)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed getting current server to renew, error: %s",
+ GetErrorTraceback(currentServerErr),
+ ),
+ )
return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
}
loginErr := state.ensureLogin(currentServer)
if loginErr != nil {
- state.Logger.Warning(fmt.Sprintf("Failed logging in server for renew, error: %s", GetErrorTraceback(loginErr)))
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed logging in server for renew, error: %s",
+ GetErrorTraceback(loginErr),
+ ),
+ )
return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr}
}
@@ -679,7 +859,9 @@ func (state *VPNState) RenewSession() error {
}
func (state *VPNState) ShouldRenewButton() bool {
- if !state.InFSMState(STATE_CONNECTED) && !state.InFSMState(STATE_CONNECTING) && !state.InFSMState(STATE_DISCONNECTED) && !state.InFSMState(STATE_DISCONNECTING) {
+ if !state.InFSMState(STATE_CONNECTED) && !state.InFSMState(STATE_CONNECTING) &&
+ !state.InFSMState(STATE_DISCONNECTED) &&
+ !state.InFSMState(STATE_DISCONNECTING) {
return false
}
@@ -717,3 +899,7 @@ func GetErrorTraceback(err error) string {
func GetErrorJSONString(err error) (string, error) {
return types.GetErrorJSONString(err)
}
+
+func (state *VPNState) GetTranslated(languages map[string]string) string {
+ return util.GetLanguageMatched(languages, state.Language)
+}
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index 1ab29cc..0bd2502 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -3,6 +3,8 @@ from eduvpn_common.state import State, StateType
import webbrowser
import json
import sys
+import time
+from typing import List
# Asks the user for a profile index
# It loops up until a valid input is given
@@ -27,6 +29,11 @@ def ask_profile_input(total: int) -> int:
# Sets up the callbacks using the provided class
def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None:
# The callback that starst OAuth
+ @_eduvpn.event.on(State.NO_SERVER, StateType.Enter)
+ def no_server(old_state: str, servers) -> None:
+ for server in servers:
+ print(type(server))
+ print(server)
# It needs to open the URL in the web browser
@_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter)
def oauth_initialized(old_state: str, url: str) -> None:
@@ -34,31 +41,30 @@ def setup_callbacks(_eduvpn: eduvpn.EduVPN) -> None:
webbrowser.open(url)
@_eduvpn.event.on(State.ASK_LOCATION, StateType.Enter)
- def ask_location(old_state: str, locations: str):
- print("Locations: ", locations)
- _eduvpn.set_secure_location("NL")
+ def ask_location(old_state: str, locations: List[str]):
+ _eduvpn.set_secure_location(locations[1])
- # The callback which asks the user for a profile
- @_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter)
- def ask_profile(old_state: str, profiles: str):
- print("Multiple profiles found, you need to select a profile:")
+ ## The callback which asks the user for a profile
+ #@_eduvpn.event.on(State.ASK_PROFILE, StateType.Enter)
+ #def ask_profile(old_state: str, profiles: str):
+ # print("Multiple profiles found, you need to select a profile:")
- # Parse the profiles as JSON
- data = json.loads(profiles)
+ # # Parse the profiles as JSON
+ # data = json.loads(profiles)
- # Get a lits of profiles
- profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]]
- total_profiles = len(profile_strings)
+ # # Get a lits of profiles
+ # profile_strings = [x["profile_id"] for x in data["info"]["profile_list"]]
+ # total_profiles = len(profile_strings)
- # Create a list of the strings to standard output
- for idx, profile in enumerate(profile_strings):
- print(f"{idx+1}. {profile}")
+ # # Create a list of the strings to standard output
+ # for idx, profile in enumerate(profile_strings):
+ # print(f"{idx+1}. {profile}")
- # Get the profile index from the user
- profile_index = ask_profile_input(total_profiles)
+ # # Get the profile index from the user
+ # profile_index = ask_profile_input(total_profiles)
- # Set the profile with the index
- _eduvpn.set_profile(profile_strings[profile_index])
+ # # Set the profile with the index
+ # _eduvpn.set_profile(profile_strings[profile_index])
# The main entry point
@@ -72,18 +78,13 @@ if __name__ == "__main__":
except Exception as e:
print("Failed registering:", e)
- server = input(
- "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
- )
-
- # Ensure we have a valid http prefix
- if not server.startswith("http"):
- # https by default
- server = "https://" + server
+ #server = input(
+ # "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
+ #)
# Get a Wireguard/OpenVPN config
try:
- config, config_type = _eduvpn.get_config_custom_server(server)
+ config, config_type = _eduvpn.get_config_secure_internet("https://idp.geant.org")
print(f"Got a config with type: {config_type} and contents:\n{config}")
except Exception as e:
print("Failed to connect:", e)
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index 5b63651..db5484f 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -40,12 +40,66 @@ class ErrorLevel(Enum):
ERR_OTHER = 0
ERR_INFO = 1
+class cServerLocations(Structure):
+ _fields_ = [
+ ("locations", POINTER(c_char_p)),
+ ("total_locations", c_size_t)
+ ]
+
+class cDiscoveryOrganization(Structure):
+ _fields_ = [
+ ("display_name", c_char_p),
+ ("org_id", c_char_p),
+ ("secure_internet_home", c_char_p),
+ ("keyword_list", c_char_p),
+ ]
+
+class cDiscoveryOrganizations(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("organizations", POINTER(POINTER(cDiscoveryOrganization))),
+ ("total_organizations", c_size_t),
+ ]
+
+class cServerProfile(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("default_gateway", c_int),
+ ]
+
+class cServerProfiles(Structure):
+ _fields_ = [
+ ("current", c_int),
+ ("profiles", POINTER(POINTER(cServerProfile))),
+ ("total_profiles", c_size_t),
+ ]
+
+class cServer(Structure):
+ _fields_ = [
+ ("identifier", c_char_p),
+ ("display_name", c_char_p),
+ ("country_code", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ("profiles", POINTER(cServerProfiles)),
+ ("expire_time", c_ulonglong),
+ ]
+
+class cServers(Structure):
+ _fields_ = [
+ ("custom_servers", POINTER(POINTER(cServer))),
+ ("total_custom", c_size_t),
+ ("institute_servers", POINTER(POINTER(cServer))),
+ ("total_institute", c_size_t),
+ ("secure_internet", POINTER(cServer)),
+ ]
class DataError(Structure):
_fields_ = [("data", c_void_p), ("error", c_void_p)]
-VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_char_p)
+VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
@@ -77,7 +131,7 @@ lib.Register.argtypes, lib.Register.restype = [
], c_void_p
lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
c_char_p
-], DataError
+], c_void_p
lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
@@ -96,8 +150,12 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c
lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
+lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
+lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None
+lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
+lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p
class WrappedError:
@@ -139,6 +197,8 @@ def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]:
if not error_json:
return None
+ if "level" not in error_json:
+ return error_string
level = error_json["level"]
traceback = error_json["traceback"]
cause = error_json["cause"]
@@ -149,6 +209,9 @@ def get_error(ptr: c_void_p) -> str:
error = get_ptr_error(ptr)
if not error:
return ""
+
+ if not isinstance(error, WrappedError):
+ return error
return error.cause
@@ -161,7 +224,6 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]:
def get_bool(boolInt: c_int) -> bool:
return boolInt == 1
-
decode_map = {
c_int: get_bool,
c_void_p: get_error,
diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py
new file mode 100644
index 0000000..80c08cf
--- /dev/null
+++ b/wrappers/python/src/discovery.py
@@ -0,0 +1,43 @@
+from . import lib, cDiscoveryOrganizations
+from ctypes import cast, POINTER
+
+
+class DiscoOrganization:
+ def __init__(self, display_name, org_id, secure_internet_home, keyword_list):
+ self.display_name = display_name
+ self.org_id = org_id
+ self.secure_internet_home = secure_internet_home
+ self.keyword_list = keyword_list
+
+
+class DiscoOrganizations:
+ def __init__(self, version, organizations):
+ self.version = version
+ self.organizations = organizations
+
+
+def get_disco_organization(ptr):
+ if not ptr:
+ return None
+
+ current_organization = ptr.contents
+ display_name = current_organization.display_name.decode("utf-8")
+ org_id = current_organization.org_id.decode("utf-8")
+ secure_internet_home = current_organization.secure_internet_home.decode("utf-8")
+ keyword_list = current_organization.keyword_list.decode("utf-8")
+ return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
+
+
+def get_disco_organizations(ptr):
+ if ptr:
+ orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
+ organizations = []
+ if orgs.organizations:
+ for i in range(orgs.total_organizations):
+ current = get_disco_organization(orgs.organizations[i])
+ if current is None:
+ continue
+ organizations.append(current)
+ lib.FreeDiscoOrganizations(ptr)
+ return DiscoOrganizations(orgs.version, organizations)
+ return None
diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py
index d0740f8..0e0f5ae 100644
--- a/wrappers/python/src/event.py
+++ b/wrappers/python/src/event.py
@@ -1,7 +1,8 @@
-from . import VPNStateChange
+from . import VPNStateChange, get_ptr_string
from enum import Enum
from typing import Callable
-from .state import StateType
+from .state import State, StateType
+from .server import get_locations, get_servers
EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
@@ -15,6 +16,15 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
+def convert_data(state: State, data):
+ if not data:
+ return None
+ if state is State.NO_SERVER:
+ return get_servers(data)
+ if state is State.OAUTH_STARTED:
+ return get_ptr_string(data)
+ if state is State.ASK_LOCATION:
+ return get_locations(data)
class EventHandler(object):
def __init__(self):
@@ -73,6 +83,7 @@ class EventHandler(object):
def run(self, old_state: int, new_state: int, data: str) -> None:
# First run leave transitions, then enter
# The state is done when the wait event finishes
- self.run_state(old_state, new_state, StateType.Leave, data)
- self.run_state(new_state, old_state, StateType.Enter, data)
- self.run_state(new_state, old_state, StateType.Wait, data)
+ converted = convert_data(new_state, data)
+ self.run_state(old_state, new_state, StateType.Leave, converted)
+ self.run_state(new_state, old_state, StateType.Enter, converted)
+ self.run_state(new_state, old_state, StateType.Wait, converted)
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index b37842f..cbeadb5 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,8 +1,10 @@
from . import lib, VPNStateChange, encode_args, decode_res
from typing import Optional, Tuple
import threading
+from .discovery import get_disco_organizations
from .event import EventHandler
from .state import State, StateType
+from .server import get_servers
import json
eduvpn_objects = {}
@@ -26,7 +28,7 @@ def state_callback(name, old_state, new_state, data):
name = name.decode()
if name not in eduvpn_objects:
return
- eduvpn_objects[name].callback(State(old_state), State(new_state), data.decode())
+ eduvpn_objects[name].callback(State(old_state), State(new_state), data)
class EduVPN(object):
@@ -58,6 +60,12 @@ class EduVPN(object):
res = func(self.name.encode("utf-8"), *(args_gen))
return decode_res(func.restype)(res)
+ def go_function_custom_decode(self, func, decode_func, *args):
+ # The functions all have at least one arg type which is the name of the client
+ args_gen = encode_args(list(args), func.argtypes[1:])
+ res = func(self.name.encode("utf-8"), *(args_gen))
+ return decode_func(res)
+
def cancel_oauth(self) -> None:
cancel_oauth_err = self.go_function(lib.CancelOAuth)
@@ -91,10 +99,9 @@ class EduVPN(object):
return servers
def get_disco_organizations(self) -> str:
- organizations, organizations_err = self.go_function(lib.GetDiscoOrganizations)
-
- if organizations_err:
- raise Exception(organizations_err)
+ organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations)
+ #if organizations_err:
+ # raise Exception(organizations_err)
return organizations
@@ -196,7 +203,7 @@ class EduVPN(object):
def event(self) -> EventHandler:
return self.event_handler
- def callback(self, old_state: State, new_state: State, data: str) -> None:
+ def callback(self, old_state: State, new_state: State, data) -> None:
self.event.run(old_state, new_state, data)
def set_profile(self, profile_id: str) -> None:
@@ -242,3 +249,9 @@ class EduVPN(object):
def in_fsm_state(self, state_id: State) -> bool:
return self.go_function(lib.InFSMState, state_id)
+
+ def get_saved_servers_old(self) -> str:
+ return self.go_function(lib.GetSavedServersOLD)
+
+ def get_saved_servers_new(self) -> str:
+ return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers)
diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py
new file mode 100644
index 0000000..b765ede
--- /dev/null
+++ b/wrappers/python/src/server.py
@@ -0,0 +1,158 @@
+from . import lib, cServers, cServerLocations
+from ctypes import cast, POINTER
+
+
+class Profile:
+ def __init__(self, identifier, display_name, default_gateway: bool):
+ self.identifier = identifier
+ self.display_name = display_name
+ self.default_gateway = default_gateway
+
+ def __str__(self):
+ return f"Profile: {self.display_name}"
+
+
+class Server:
+ def __init__(self, url, display_name, profiles, current_profile, expire_time):
+ self.url = url
+ self.display_name = display_name
+ self.profiles = profiles
+ self.current_profile = None
+ if current_profile < len(profiles):
+ self.current_profile = profiles[current_profile]
+ self.expire_time = expire_time
+
+ def __str__(self):
+ return f"Server: {self.url}, with current profile: {self.current_profile}"
+
+
+class InstituteServer(Server):
+ def __init__(
+ self, url, display_name, support_contact, profiles, current_profile, expire_time
+ ):
+ super().__init__(url, display_name, profiles, current_profile, expire_time)
+ self.support_contact = support_contact
+
+ def __str__(self):
+ return f"Institute Server: {self.display_name}"
+
+
+class SecureInternetServer(Server):
+ def __init__(
+ self,
+ url,
+ display_name,
+ support_contact,
+ profiles,
+ current_profile,
+ expire_time,
+ country_code,
+ ):
+ super().__init__(url, display_name, profiles, current_profile, expire_time)
+ self.support_contact = support_contact
+ self.country_code = country_code
+
+ def __str__(self):
+ return f"Secure Internet Server: {self.display_name} with country {self.country_code}"
+
+
+def get_type_for_str(type_str: str):
+ if type_str is "secure_internet":
+ return SecureInternetServer
+ if type_str is "custom_server":
+ return Server
+ return InstituteServer
+
+
+def get_server(ptr, _type=None):
+ if not ptr:
+ return None
+
+ current_server = ptr.contents
+ if _type is None:
+ _type = get_type_for_str(current_server.server_type.decode("utf-8"))
+
+ identifier = current_server.identifier.decode("utf-8")
+ display_name = current_server.display_name.decode("utf-8")
+
+ if _type is not Server:
+ support_contact = []
+ for i in range(current_server.total_support_contact):
+ support_contact.append(current_server.support_contact[i].decode("utf-8"))
+ profiles = []
+ if not current_server.profiles:
+ return None
+
+ _profiles = current_server.profiles.contents
+ current_profile = _profiles.current
+ for i in range(_profiles.total_profiles):
+ if not _profiles.profiles or not _profiles.profiles[i]:
+ return None
+ profile = _profiles.profiles[i].contents
+ profiles.append(
+ Profile(
+ profile.identifier.decode("utf-8"),
+ profile.display_name.decode("utf-8"),
+ profile.default_gateway == 1,
+ )
+ )
+
+ if _type is SecureInternetServer:
+ return SecureInternetServer(
+ identifier,
+ display_name,
+ support_contact,
+ profiles,
+ current_profile,
+ current_server.expire_time,
+ current_server.country_code.decode("utf-8"),
+ )
+ if _type is InstituteServer:
+ return InstituteServer(
+ identifier,
+ display_name,
+ support_contact,
+ profiles,
+ current_profile,
+ current_server.expire_time,
+ )
+ return Server(
+ identifier, display_name, profiles, current_profile, current_server.expire_time
+ )
+
+
+def get_servers(ptr):
+ if ptr:
+ returned = []
+ servers = cast(ptr, POINTER(cServers)).contents
+ if servers.custom_servers:
+ for i in range(servers.total_custom):
+ current = get_server(servers.custom_servers[i], Server)
+ if current is None:
+ continue
+ returned.append(current)
+
+ if servers.institute_servers:
+ for i in range(servers.total_institute):
+ current = get_server(servers.institute_servers[i], InstituteServer)
+ if current is None:
+ continue
+ returned.append(current)
+
+ if servers.secure_internet:
+ current = get_server(servers.secure_internet, SecureInternetServer)
+ if current is not None:
+ returned.append(current)
+ lib.FreeServers(ptr)
+ return returned
+ return None
+
+def get_locations(ptr):
+ if ptr:
+ locations = cast(ptr, POINTER(cServerLocations)).contents
+ location_list = []
+ for i in range(locations.total_locations):
+ location_list.append(locations.locations[i].decode("utf-8"))
+ lib.FreeSecureLocations(ptr)
+ return location_list
+ return None
diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py
index 60d3cce..679eda0 100644
--- a/wrappers/python/tests.py
+++ b/wrappers/python/tests.py
@@ -24,7 +24,7 @@ class ConfigTests(unittest.TestCase):
@_eduvpn.event.on(State.OAUTH_STARTED, StateType.Enter)
def oauth_initialized(old_state, url_json):
- login_eduvpn(json.loads(url_json))
+ login_eduvpn(url_json)
server_uri = os.getenv("SERVER_URI")
if not server_uri: