summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-16 10:46:28 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-09-16 10:46:28 +0200
commit4bf1273c3f446ac3195fb700ec41c7cae7d20ac9 (patch)
treecec8d9e405b7d6786023ca9b921a6f0473d28a71
parent02db081c85e56e6472c2f39e6a623fa4cdf359c4 (diff)
Discovery: Expose c types
-rw-r--r--exports/c/disco.h19
-rw-r--r--exports/c/servers.h1
-rw-r--r--exports/disco.go98
-rw-r--r--exports/exports.go21
-rw-r--r--exports/servers.go65
-rw-r--r--internal/discovery/discovery.go8
-rw-r--r--internal/server/common.go103
-rw-r--r--internal/types/server.go1
-rw-r--r--state.go41
-rw-r--r--wrappers/python/src/__init__.py73
-rw-r--r--wrappers/python/src/discovery.py85
-rw-r--r--wrappers/python/src/event.py26
-rw-r--r--wrappers/python/src/main.py20
-rw-r--r--wrappers/python/src/server.py109
14 files changed, 455 insertions, 215 deletions
diff --git a/exports/c/disco.h b/exports/c/disco.h
index 41d59fa..8fa07a4 100644
--- a/exports/c/disco.h
+++ b/exports/c/disco.h
@@ -1,6 +1,25 @@
// for size_t
#include <stddef.h>
+typedef struct discoveryServer {
+ const char* authentication_url_template;
+ const char* base_url;
+ const char* country_code;
+ const char* display_name;
+ const char* keyword_list;
+ const char** public_key_list;
+ size_t total_public_keys;
+ const char* server_type;
+ const char** support_contact;
+ size_t total_support_contact;
+} discoveryServer;
+
+typedef struct discoveryServers {
+ unsigned long long int version;
+ discoveryServer** servers;
+ size_t total_servers;
+} discoveryServers;
+
typedef struct discoveryOrganization {
const char* display_name;
const char* org_id;
diff --git a/exports/c/servers.h b/exports/c/servers.h
index 39e52a2..1b6cca9 100644
--- a/exports/c/servers.h
+++ b/exports/c/servers.h
@@ -26,6 +26,7 @@ typedef struct serverLocations {
typedef struct server {
const char* identifier;
const char* display_name;
+ const char* server_type;
const char* country_code;
const char** support_contact;
size_t total_support_contact;
diff --git a/exports/disco.go b/exports/disco.go
index 9ee2af9..a08f8ba 100644
--- a/exports/disco.go
+++ b/exports/disco.go
@@ -33,9 +33,8 @@ func getCPtrDiscoOrganizations(
organizations *types.DiscoveryOrganizations,
) (C.size_t, **C.discoveryOrganization) {
totalOrganizations := C.size_t(len(organizations.List))
- var organizationsPtr **C.discoveryOrganization
if totalOrganizations > 0 {
- organizationsPtr = (**C.discoveryOrganization)(
+ organizationsPtr := (**C.discoveryOrganization)(
C.malloc(totalOrganizations * C.size_t(unsafe.Sizeof(uintptr(0)))),
)
cOrganizations := (*[1<<30 - 1]*C.discoveryOrganization)(unsafe.Pointer(organizationsPtr))[:totalOrganizations:totalOrganizations]
@@ -45,8 +44,52 @@ func getCPtrDiscoOrganizations(
cOrganizations[index] = cOrganization
index += 1
}
+ return totalOrganizations, organizationsPtr
}
- return totalOrganizations, organizationsPtr
+ return 0, nil
+}
+
+func getCPtrDiscoServer(
+ state *eduvpn.VPNState,
+ server *types.DiscoveryServer,
+) *C.discoveryServer {
+ returnedStruct := (*C.discoveryServer)(
+ C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServer{}))),
+ )
+ returnedStruct.authentication_url_template = C.CString(server.AuthenticationURLTemplate)
+ returnedStruct.base_url = C.CString(server.BaseURL)
+ returnedStruct.country_code = C.CString(server.CountryCode)
+ returnedStruct.display_name = C.CString(state.GetTranslated(server.DisplayName))
+ returnedStruct.keyword_list = C.CString(state.GetTranslated(server.KeywordList))
+ returnedStruct.total_public_keys, returnedStruct.public_key_list = getCPtrListStrings(
+ server.PublicKeyList,
+ )
+ returnedStruct.server_type = C.CString(server.Type)
+ returnedStruct.total_support_contact, returnedStruct.support_contact = getCPtrListStrings(
+ server.SupportContact,
+ )
+ return returnedStruct
+}
+
+func getCPtrDiscoServers(
+ state *eduvpn.VPNState,
+ servers *types.DiscoveryServers,
+) (C.size_t, **C.discoveryServer) {
+ totalServers := C.size_t(len(servers.List))
+ if totalServers > 0 {
+ serversPtr := (**C.discoveryServer)(
+ C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0)))),
+ )
+ cServers := (*[1<<30 - 1]*C.discoveryServer)(unsafe.Pointer(serversPtr))
+ index := 0
+ for _, server := range servers.List {
+ cServer := getCPtrDiscoServer(state, &server)
+ cServers[index] = cServer
+ index += 1
+ }
+ return totalServers, serversPtr
+ }
+ return 0, nil
}
func freeDiscoOrganization(cOrganization *C.discoveryOrganization) {
@@ -57,6 +100,30 @@ func freeDiscoOrganization(cOrganization *C.discoveryOrganization) {
C.free(unsafe.Pointer(cOrganization))
}
+func freeDiscoServer(cServer *C.discoveryServer) {
+ C.free(unsafe.Pointer(cServer.authentication_url_template))
+ C.free(unsafe.Pointer(cServer.base_url))
+ C.free(unsafe.Pointer(cServer.country_code))
+ C.free(unsafe.Pointer(cServer.display_name))
+ C.free(unsafe.Pointer(cServer.keyword_list))
+ freeCListStrings(cServer.public_key_list, cServer.total_public_keys)
+ C.free(unsafe.Pointer(cServer.server_type))
+ freeCListStrings(cServer.support_contact, cServer.total_support_contact)
+ C.free(unsafe.Pointer(cServer))
+}
+
+//export FreeDiscoServers
+func FreeDiscoServers(cServers *C.discoveryServers) {
+ if cServers.total_servers > 0 {
+ servers := (*[1<<30 - 1]*C.discoveryServer)(unsafe.Pointer(cServers.servers))[:cServers.total_servers:cServers.total_servers]
+ for i := C.size_t(0); i < cServers.total_servers; i++ {
+ freeDiscoServer(servers[i])
+ }
+ C.free(unsafe.Pointer(cServers.servers))
+ }
+ C.free(unsafe.Pointer(cServers))
+}
+
//export FreeDiscoOrganizations
func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) {
if cOrganizations.total_organizations > 0 {
@@ -69,6 +136,31 @@ func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) {
C.free(unsafe.Pointer(cOrganizations))
}
+//export GetDiscoServers
+func GetDiscoServers(name *C.char) *C.discoveryServers {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ // TODO
+ if stateErr != nil {
+ panic(stateErr)
+ }
+ servers, serversErr := state.GetDiscoServers()
+ // TODO
+ if serversErr != nil {
+ panic(serversErr)
+ }
+
+ returnedStruct := (*C.discoveryServers)(
+ C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServers{}))),
+ )
+ returnedStruct.version = C.ulonglong(servers.Version)
+ returnedStruct.total_servers, returnedStruct.servers = getCPtrDiscoServers(
+ state,
+ servers,
+ )
+ return returnedStruct
+}
+
//export GetDiscoOrganizations
func GetDiscoOrganizations(name *C.char) *C.discoveryOrganizations {
nameStr := C.GoString(name)
diff --git a/exports/exports.go b/exports/exports.go
index 797a192..0144500 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -38,6 +38,16 @@ func GetStateData(
}
case eduvpn.STATE_ASK_LOCATION:
return (unsafe.Pointer)(getTransitionSecureLocations(data))
+ case eduvpn.STATE_ASK_PROFILE:
+ return (unsafe.Pointer)(getTransitionProfiles(data))
+ case eduvpn.STATE_DISCONNECTED:
+ return (unsafe.Pointer)(getTransitionServer(state, data))
+ case eduvpn.STATE_DISCONNECTING:
+ return (unsafe.Pointer)(getTransitionServer(state, data))
+ case eduvpn.STATE_CONNECTING:
+ return (unsafe.Pointer)(getTransitionServer(state, data))
+ case eduvpn.STATE_CONNECTED:
+ return (unsafe.Pointer)(getTransitionServer(state, data))
default:
return nil
}
@@ -229,17 +239,6 @@ func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char,
return getConfigJSON(config, configType), C.CString(ErrorToString(configErr))
}
-//export GetDiscoServers
-func GetDiscoServers(name *C.char) (*C.char, *C.char) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, C.CString(ErrorToString(stateErr))
- }
- servers, serversErr := state.GetDiscoServers()
- return C.CString(servers), C.CString(ErrorToString(serversErr))
-}
-
//export SetProfileID
func SetProfileID(name *C.char, data *C.char) *C.char {
nameStr := C.GoString(name)
diff --git a/exports/servers.go b/exports/servers.go
index f92c08e..d33ac1f 100644
--- a/exports/servers.go
+++ b/exports/servers.go
@@ -51,7 +51,7 @@ func getCPtrProfiles(serverProfiles *server.ServerProfileInfo) *C.serverProfiles
profiles[index] = getCPtrProfile(&profile)
index += 1
}
- // TODO: DO CURRENT PROFILE
+ cProfiles.current = C.int(serverProfiles.GetCurrentProfileIndex())
cProfiles.profiles = (**C.serverProfile)(profilesPtr)
}
return cProfiles
@@ -59,7 +59,8 @@ func getCPtrProfiles(serverProfiles *server.ServerProfileInfo) *C.serverProfiles
// Free the profiles by looping through them if there are any
// Also free the pointer itself
-func freeCProfiles(profiles *C.serverProfiles) {
+//export FreeProfiles
+func FreeProfiles(profiles *C.serverProfiles) {
// We should only free the profiles if we have them (which we should)
if profiles.total_profiles > 0 {
// Convert it to a go slice
@@ -119,12 +120,21 @@ func freeCListStrings(allStrings **C.char, totalStrings C.size_t) {
// Function for getting the server,
// It gets the main state as a pointer as we need to convert some string maps to localized strings
// It gets the base information for a server as well
-func getServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server {
+func getCPtrServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server {
// Allocation using malloc and the size of the struct
server := (*C.server)(C.malloc(C.size_t(unsafe.Sizeof(C.server{}))))
// String allocation and translate the display name
- server.identifier = C.CString(base.URL)
+ identifier := base.URL
+ countryCode := ""
+ if base.Type == "secure_internet" {
+ identifier = state.Servers.SecureInternetHomeServer.HomeOrganizationID
+ countryCode = state.Servers.SecureInternetHomeServer.CurrentLocation
+ }
+
+ server.identifier = C.CString(identifier)
server.display_name = C.CString(state.GetTranslated(base.DisplayName))
+ server.country_code = C.CString(countryCode)
+ server.server_type = C.CString(base.Type)
// Call the helper to get the list of support contacts
server.total_support_contact, server.support_contact = getCPtrListStrings(
base.SupportContact,
@@ -133,22 +143,26 @@ func getServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server {
// No endtime is given if we get servers when it has been partially initialised
if base.EndTime.IsZero() {
server.expire_time = C.ulonglong(0)
+ } else {
+ // The expire time should be stored as an unsigned long long in unix itme
+ server.expire_time = C.ulonglong(base.EndTime.Unix())
}
- // The expire time should be stored as an unsigned long long in unix itme
- server.expire_time = C.ulonglong(base.EndTime.Unix())
return server
}
// Function for freeing a single server
// Gets the pointer to C struct
-func freeServer(info *C.server) {
+//export FreeServer
+func FreeServer(info *C.server) {
// Free strings
C.free(unsafe.Pointer(info.identifier))
C.free(unsafe.Pointer(info.display_name))
+ C.free(unsafe.Pointer(info.country_code))
+ C.free(unsafe.Pointer(info.server_type))
// Free arrays
freeCListStrings(info.support_contact, info.total_support_contact)
- freeCProfiles(info.profiles)
+ FreeProfiles(info.profiles)
// Free the struct itself
C.free(unsafe.Pointer(info))
@@ -166,10 +180,11 @@ func getCPtrServers(
servers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(serversPtr))[:totalServers:totalServers]
index := 0
for _, server := range serverMap {
- cServer := getServer(state, &server.Base)
+ cServer := getCPtrServer(state, &server.Base)
servers[index] = cServer
index += 1
}
+ return totalServers, serversPtr
}
return C.size_t(0), nil
}
@@ -182,7 +197,7 @@ func FreeServers(cServers *C.servers) {
if cServers.total_custom > 0 {
customServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.custom_servers))[:cServers.total_custom:cServers.total_custom]
for i := C.size_t(0); i < cServers.total_custom; i++ {
- freeServer(customServers[i])
+ FreeServer(customServers[i])
}
C.free(unsafe.Pointer(cServers.custom_servers))
}
@@ -191,14 +206,13 @@ func FreeServers(cServers *C.servers) {
instituteServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.institute_servers))[:cServers.total_institute:cServers.total_institute]
for i := C.size_t(0); i < cServers.total_institute; i++ {
- freeServer(instituteServers[i])
+ FreeServer(instituteServers[i])
}
C.free(unsafe.Pointer(cServers.institute_servers))
}
// Free the secure internet server if there is one
if cServers.secure_internet_server != nil {
- C.free(unsafe.Pointer(cServers.secure_internet_server.country_code))
- freeServer(cServers.secure_internet_server)
+ FreeServer(cServers.secure_internet_server)
}
// Free the structure itself
C.free(unsafe.Pointer(cServers))
@@ -219,7 +233,7 @@ func getSavedServersWithOptions(state *eduvpn.VPNState, servers *server.Servers)
secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase()
if secureInternetBaseErr == nil && secureInternetBase != nil {
// FIXME: log error?
- secureServerPtr = getServer(state, secureInternetBase)
+ secureServerPtr = getCPtrServer(state, secureInternetBase)
// Give a new identifier
C.free(unsafe.Pointer(secureServerPtr.identifier))
secureServerPtr.identifier = C.CString(servers.SecureInternetHomeServer.HomeOrganizationID)
@@ -259,11 +273,11 @@ func getTransitionDataServers(state *eduvpn.VPNState, data interface{}) *C.serve
//export FreeSecureLocations
func FreeSecureLocations(locations *C.serverLocations) {
- freeCListStrings(locations.locations, locations.total_locations);
+ freeCListStrings(locations.locations, locations.total_locations)
C.free(unsafe.Pointer(locations))
}
-func getTransitionSecureLocations(data interface{}) (*C.serverLocations) {
+func getTransitionSecureLocations(data interface{}) *C.serverLocations {
if locations, ok := data.([]string); ok {
returnedStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{}))))
returnedStruct.total_locations, returnedStruct.locations = getCPtrListStrings(locations)
@@ -271,3 +285,22 @@ func getTransitionSecureLocations(data interface{}) (*C.serverLocations) {
}
return nil
}
+
+func getTransitionProfiles(data interface{}) *C.serverProfiles {
+ if profiles, ok := data.(*server.ServerProfileInfo); ok {
+ return getCPtrProfiles(profiles)
+ }
+ return nil
+}
+
+func getTransitionServer(state *eduvpn.VPNState, data interface{}) *C.server {
+ if server, ok := data.(server.Server); ok {
+ base, baseErr := server.GetBase()
+ if baseErr != nil {
+ // TODO: LOG
+ return nil
+ }
+ return getCPtrServer(state, base)
+ }
+ return nil
+}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 21125cb..b3b438c 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -182,15 +182,15 @@ func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganization
}
// Get the server list
-func (discovery *Discovery) GetServersList() (string, error) {
+func (discovery *Discovery) GetServersList() (*types.DiscoveryServers, error) {
if !discovery.DetermineServersUpdate() {
- return discovery.Servers.RawString, nil
+ return &discovery.Servers, nil
}
file := "server_list.json"
body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
if bodyErr != nil {
// Return previous with an error
- return discovery.Servers.RawString, &types.WrappedErrorMessage{
+ return &discovery.Servers, &types.WrappedErrorMessage{
Message: "failed getting servers in Discovery",
Err: bodyErr,
}
@@ -198,7 +198,7 @@ func (discovery *Discovery) GetServersList() (string, error) {
// Update servers timestamp
discovery.Servers.RawString = body
discovery.Servers.Timestamp = util.GetCurrentTime()
- return discovery.Servers.RawString, nil
+ return &discovery.Servers, nil
}
type GetOrgByIDNotFoundError struct {
diff --git a/internal/server/common.go b/internal/server/common.go
index 64b8079..9c941cf 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -64,6 +64,18 @@ type ServerProfileInfo struct {
} `json:"info"`
}
+func (info ServerProfileInfo) GetCurrentProfileIndex() int {
+ index := 0
+ for _, profile := range info.Info.ProfileList {
+ if profile.ID == info.Current {
+ return index
+ }
+ index += 1
+ }
+ // Default is 'first' profile
+ return 0
+}
+
type ServerEndpointList struct {
API string `json:"api_endpoint"`
Authorization string `json:"authorization_endpoint"`
@@ -118,97 +130,6 @@ func (servers *Servers) GetCurrentServer() (Server, error) {
return server, nil
}
-type ServersConfiguredScreen struct {
- CustomServers []ServerInfoScreen `json:"custom_servers"`
- InstituteAccessServers []ServerInfoScreen `json:"institute_access_servers"`
- SecureInternetServer *ServerInfoScreen `json:"secure_internet_server"`
-}
-
-type ServerInfoScreen struct {
- Identifier string `json:"identifier"`
- DisplayName map[string]string `json:"display_name"`
- CountryCode string `json:"country_code,omitempty"`
- SupportContact []string `json:"support_contact"`
- Profiles ServerProfileInfo `json:"profiles"`
- ExpireTime int64 `json:"expire_time"`
- Type string `json:"server_type"`
-}
-
-func getServerInfoScreen(base ServerBase) ServerInfoScreen {
- serverInfoScreen := ServerInfoScreen{}
- serverInfoScreen.Identifier = base.URL
- serverInfoScreen.DisplayName = base.DisplayName
- serverInfoScreen.SupportContact = base.SupportContact
- serverInfoScreen.Profiles = base.Profiles
-
- // If we still have the default end time, return 0
- // Such that clients will still be able to parse it correctly
- if base.EndTime.IsZero() {
- serverInfoScreen.ExpireTime = 0
- } else {
- serverInfoScreen.ExpireTime = base.EndTime.Unix()
- }
- serverInfoScreen.Type = base.Type
-
- return serverInfoScreen
-}
-
-func (servers *Servers) GetServersConfigured() *ServersConfiguredScreen {
- customServersInfo := []ServerInfoScreen{}
- instituteServersInfo := []ServerInfoScreen{}
- var secureInternetServerInfo *ServerInfoScreen = nil
-
- for _, server := range servers.CustomServers.Map {
- serverInfoScreen := getServerInfoScreen(server.Base)
- customServersInfo = append(customServersInfo, serverInfoScreen)
- }
-
- for _, server := range servers.InstituteServers.Map {
- serverInfoScreen := getServerInfoScreen(server.Base)
- instituteServersInfo = append(instituteServersInfo, serverInfoScreen)
- }
-
- secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase()
-
- if secureInternetBaseErr == nil && secureInternetBase != nil {
- // FIXME: log error?
- secureInternetServerInfoReturned := getServerInfoScreen(*secureInternetBase)
- secureInternetServerInfo = &secureInternetServerInfoReturned
- secureInternetServerInfo.Identifier = servers.SecureInternetHomeServer.HomeOrganizationID
- secureInternetServerInfo.CountryCode = servers.SecureInternetHomeServer.CurrentLocation
- }
-
- return &ServersConfiguredScreen{
- CustomServers: customServersInfo,
- InstituteAccessServers: instituteServersInfo,
- SecureInternetServer: secureInternetServerInfo,
- }
-}
-
-func (servers *Servers) GetCurrentServerInfo() (*ServerInfoScreen, error) {
- errorMessage := "failed getting current server info"
-
- currentServer, currentServerErr := servers.GetCurrentServer()
- if currentServerErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
- }
-
- base, baseErr := currentServer.GetBase()
-
- if baseErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
- }
-
- serverInfoScreen := getServerInfoScreen(*base)
-
- if servers.IsType == SecureInternetServerType {
- serverInfoScreen.Identifier = servers.SecureInternetHomeServer.HomeOrganizationID
- serverInfoScreen.CountryCode = servers.SecureInternetHomeServer.CurrentLocation
- }
-
- return &serverInfoScreen, nil
-}
-
func (servers *Servers) addInstituteAndCustom(
discoServer *types.DiscoveryServer,
isCustom bool,
diff --git a/internal/types/server.go b/internal/types/server.go
index 33a6e9c..48f94fb 100644
--- a/internal/types/server.go
+++ b/internal/types/server.go
@@ -62,6 +62,7 @@ type DiscoveryServer struct {
BaseURL string `json:"base_url"`
CountryCode string `json:"country_code"`
DisplayName DiscoMapOrString `json:"display_name,omitempty"`
+ KeywordList DiscoMapOrString `json:"keyword_list"`
PublicKeyList []string `json:"public_key_list"`
Type string `json:"server_type"`
SupportContact []string `json:"support_contact"`
diff --git a/state.go b/state.go
index 37ae939..9f7bd1b 100644
--- a/state.go
+++ b/state.go
@@ -15,7 +15,6 @@ import (
)
type (
- ServerInfo = server.ServerInfoScreen
VPNServerBase = server.ServerBase
)
@@ -42,10 +41,6 @@ type VPNState struct {
Debug bool `json:"-"`
}
-func (state *VPNState) GetSavedServers() *server.ServersConfiguredScreen {
- return state.Servers.GetServersConfigured()
-}
-
// Register initializes the state with the following parameters:
// - name: the name of the client
// - directory: the directory where the config files are stored. Absolute or relative
@@ -95,7 +90,7 @@ func (state *VPNState) Register(
_, currentServerErr := state.Servers.GetCurrentServer()
// Only actually return the error if we have no disco servers and no current server
- if discoServersErr != nil && discoServers == "" && currentServerErr != nil {
+ if discoServersErr != nil && discoServers.Version == 0 && currentServerErr != nil {
state.Logger.Error(
fmt.Sprintf(
"No configured servers, discovery servers is empty and no servers with error: %s",
@@ -263,7 +258,6 @@ func (state *VPNState) getConfig(
// Signal the server display info
state.FSM.GoTransitionWithData(STATE_DISCONNECTED, currentServer, false)
-
// Save the config
state.Config.Save(&state)
@@ -621,13 +615,13 @@ func (state *VPNState) GetDiscoOrganizations() (*types.DiscoveryOrganizations, e
return orgs, nil
}
-func (state *VPNState) GetDiscoServers() (string, error) {
+func (state *VPNState) GetDiscoServers() (*types.DiscoveryServers, error) {
servers, serversErr := state.Discovery.GetServersList()
if serversErr != nil {
state.Logger.Warning(
fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)),
)
- return "", &types.WrappedErrorMessage{
+ return nil, &types.WrappedErrorMessage{
Message: "failed getting discovery servers list",
Err: serversErr,
}
@@ -682,19 +676,6 @@ func (state *VPNState) SetSearchServer() error {
return nil
}
-func (state *VPNState) getServerInfoData() *server.ServerInfoScreen {
- info, infoErr := state.Servers.GetCurrentServerInfo()
- if infoErr != nil {
- state.Logger.Error(
- fmt.Sprintf(
- "Failed getting server info data with error: %s",
- GetErrorTraceback(infoErr),
- ),
- )
- }
- return info
-}
-
func (state *VPNState) SetConnected() error {
errorMessage := "failed to set connected"
if state.InFSMState(STATE_CONNECTED) {
@@ -734,6 +715,7 @@ func (state *VPNState) SetConnected() error {
}
func (state *VPNState) SetConnecting() error {
+ errorMessage := "failed to set connecting"
if state.InFSMState(STATE_CONNECTING) {
// already loading connection, show no error
state.Logger.Warning("Already connecting")
@@ -747,7 +729,7 @@ func (state *VPNState) SetConnecting() error {
),
)
return &types.WrappedErrorMessage{
- Message: "failed to set connecting",
+ Message: errorMessage,
Err: FSMWrongStateTransitionError{
Got: state.FSM.Current,
Want: STATE_CONNECTING,
@@ -755,7 +737,18 @@ func (state *VPNState) SetConnecting() error {
}
}
- state.FSM.GoTransition(STATE_CONNECTING)
+ currentServer, currentServerErr := state.Servers.GetCurrentServer()
+ if currentServerErr != nil {
+ state.Logger.Warning(
+ fmt.Sprintf(
+ "Failed setting connecting, cannot get current server with error: %s",
+ GetErrorTraceback(currentServerErr),
+ ),
+ )
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
+ }
+
+ state.FSM.GoTransitionWithData(STATE_CONNECTING, currentServer, false)
return nil
}
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index ece4b46..ca07143 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -5,6 +5,7 @@ import pathlib
import platform
from typing import Tuple, Optional
import json
+from typing import List
_lib_prefixes = defaultdict(
lambda: "lib",
@@ -40,11 +41,10 @@ class ErrorLevel(Enum):
ERR_OTHER = 0
ERR_INFO = 1
+
class cServerLocations(Structure):
- _fields_ = [
- ("locations", POINTER(c_char_p)),
- ("total_locations", c_size_t)
- ]
+ _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)]
+
class cDiscoveryOrganization(Structure):
_fields_ = [
@@ -54,6 +54,7 @@ class cDiscoveryOrganization(Structure):
("keyword_list", c_char_p),
]
+
class cDiscoveryOrganizations(Structure):
_fields_ = [
("version", c_ulonglong),
@@ -61,6 +62,30 @@ class cDiscoveryOrganizations(Structure):
("total_organizations", c_size_t),
]
+
+class cDiscoveryServer(Structure):
+ _fields_ = [
+ ("authentication_url_template", c_char_p),
+ ("base_url", c_char_p),
+ ("country_code", c_char_p),
+ ("display_name", c_char_p),
+ ("keyword_list", c_char_p),
+ ("public_key_list", POINTER(c_char_p)),
+ ("total_public_keys", c_size_t),
+ ("server_type", c_char_p),
+ ("support_contact", POINTER(c_char_p)),
+ ("total_support_contact", c_size_t),
+ ]
+
+
+class cDiscoveryServers(Structure):
+ _fields_ = [
+ ("version", c_ulonglong),
+ ("servers", POINTER(POINTER(cDiscoveryServer))),
+ ("total_servers", c_size_t),
+ ]
+
+
class cServerProfile(Structure):
_fields_ = [
("identifier", c_char_p),
@@ -68,6 +93,7 @@ class cServerProfile(Structure):
("default_gateway", c_int),
]
+
class cServerProfiles(Structure):
_fields_ = [
("current", c_int),
@@ -75,10 +101,12 @@ class cServerProfiles(Structure):
("total_profiles", c_size_t),
]
+
class cServer(Structure):
_fields_ = [
("identifier", c_char_p),
("display_name", c_char_p),
+ ("server_type", c_char_p),
("country_code", c_char_p),
("support_contact", POINTER(c_char_p)),
("total_support_contact", c_size_t),
@@ -86,6 +114,7 @@ class cServer(Structure):
("expire_time", c_ulonglong),
]
+
class cServers(Structure):
_fields_ = [
("custom_servers", POINTER(POINTER(cServer))),
@@ -95,6 +124,7 @@ class cServers(Structure):
("secure_internet", POINTER(cServer)),
]
+
class DataError(Structure):
_fields_ = [("data", c_void_p), ("error", c_void_p)]
@@ -104,9 +134,17 @@ VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
-lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [c_char_p], c_void_p
-lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [c_char_p, c_char_p], c_void_p
-lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [c_char_p, c_char_p], c_void_p
+lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [
+ c_char_p
+], c_void_p
+lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
+lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [
+ c_char_p,
+ c_char_p,
+], c_void_p
lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
c_char_p,
c_char_p,
@@ -132,7 +170,7 @@ lib.Register.argtypes, lib.Register.restype = [
lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [
c_char_p
], c_void_p
-lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError
+lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], c_void_p
lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None
lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p
lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p
@@ -150,9 +188,14 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c
lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p
lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int
lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p
+lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None
lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
-lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None
+lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [
+ c_void_p
+], None
+lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None
+lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None
lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None
lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int
lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p
@@ -186,6 +229,17 @@ def get_ptr_string(ptr: c_void_p) -> str:
return ""
+def get_ptr_list_strings(
+ strings: POINTER(c_char_p), total_strings: c_size_t
+) -> List[str]:
+ if strings:
+ strings_list = []
+ for i in range(total_strings):
+ strings_list.append(strings[i].decode("utf-8"))
+ return strings_list
+ return []
+
+
def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]:
error_string = get_ptr_string(ptr)
@@ -224,6 +278,7 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]:
def get_bool(boolInt: c_int) -> bool:
return boolInt == 1
+
decode_map = {
c_int: get_bool,
c_void_p: get_error,
diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py
index 80c08cf..a1b87ea 100644
--- a/wrappers/python/src/discovery.py
+++ b/wrappers/python/src/discovery.py
@@ -1,4 +1,4 @@
-from . import lib, cDiscoveryOrganizations
+from . import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings
from ctypes import cast, POINTER
@@ -9,6 +9,9 @@ class DiscoOrganization:
self.secure_internet_home = secure_internet_home
self.keyword_list = keyword_list
+ def __str__(self):
+ return self.display_name
+
class DiscoOrganizations:
def __init__(self, version, organizations):
@@ -16,6 +19,37 @@ class DiscoOrganizations:
self.organizations = organizations
+class DiscoServer:
+ def __init__(
+ self,
+ authentication_url_template,
+ base_url,
+ country_code,
+ display_name,
+ keyword_list,
+ public_keys,
+ server_type,
+ support_contacts,
+ ):
+ self.authentication_url_template = authentication_url_template
+ self.base_url = base_url
+ self.country_code = country_code
+ self.display_name = display_name
+ self.keyword_list = keyword_list
+ self.public_keys = public_keys
+ self.server_type = server_type
+ self.support_contacts = support_contacts
+
+ def __str__(self):
+ return self.display_name
+
+
+class DiscoServers:
+ def __init__(self, version, servers):
+ self.version = version
+ self.servers = servers
+
+
def get_disco_organization(ptr):
if not ptr:
return None
@@ -28,6 +62,55 @@ def get_disco_organization(ptr):
return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list)
+def get_disco_server(ptr):
+ if not ptr:
+ return None
+
+ current_server = ptr.contents
+ authentication_url_template = current_server.authentication_url_template.decode(
+ "utf-8"
+ )
+ base_url = current_server.base_url.decode("utf-8")
+ country_code = current_server.country_code.decode("utf-8")
+ display_name = current_server.display_name.decode("utf-8")
+ keyword_list = current_server.keyword_list.decode("utf-8")
+ public_keys = get_ptr_list_strings(
+ current_server.public_key_list, current_server.total_public_keys
+ )
+ server_type = current_server.server_type.decode("utf-8")
+ support_contacts = get_ptr_list_strings(
+ current_server.support_contact, current_server.total_support_contact
+ )
+ return DiscoServer(
+ authentication_url_template,
+ base_url,
+ country_code,
+ display_name,
+ keyword_list,
+ public_keys,
+ server_type,
+ support_contacts,
+ )
+
+
+def get_disco_servers(ptr):
+ if ptr:
+ svrs = cast(ptr, POINTER(cDiscoveryServers)).contents
+
+ servers = []
+
+ if svrs.servers:
+ for i in range(svrs.total_servers):
+ current = get_disco_server(svrs.servers[i])
+
+ if current is None:
+ continue
+ servers.append(current)
+ lib.FreeDiscoServers(ptr)
+ return DiscoServers(svrs.version, servers)
+ return None
+
+
def get_disco_organizations(ptr):
if ptr:
orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents
diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py
index 0e0f5ae..cf1a9d0 100644
--- a/wrappers/python/src/event.py
+++ b/wrappers/python/src/event.py
@@ -2,7 +2,12 @@ from . import VPNStateChange, get_ptr_string
from enum import Enum
from typing import Callable
from .state import State, StateType
-from .server import get_locations, get_servers
+from .server import (
+ get_locations,
+ get_transition_profiles,
+ get_transition_server,
+ get_servers,
+)
EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback"
@@ -16,6 +21,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable:
return wrapper
+
def convert_data(state: State, data):
if not data:
return None
@@ -25,6 +31,16 @@ def convert_data(state: State, data):
return get_ptr_string(data)
if state is State.ASK_LOCATION:
return get_locations(data)
+ if state is State.ASK_PROFILE:
+ return get_transition_profiles(data)
+ if state in [
+ State.DISCONNECTED,
+ State.DISCONNECTING,
+ State.CONNECTING,
+ State.CONNECTED,
+ ]:
+ return get_transition_server(data)
+
class EventHandler(object):
def __init__(self):
@@ -80,10 +96,14 @@ class EventHandler(object):
for func in self.handlers[(state, state_type)]:
func(other_state, data)
- def run(self, old_state: int, new_state: int, data: str) -> None:
+ def run(
+ self, old_state: int, new_state: int, data: str, convert: bool = True
+ ) -> None:
# First run leave transitions, then enter
# The state is done when the wait event finishes
- converted = convert_data(new_state, data)
+ converted = data
+ if convert:
+ converted = convert_data(new_state, data)
self.run_state(old_state, new_state, StateType.Leave, converted)
self.run_state(new_state, old_state, StateType.Enter, converted)
self.run_state(new_state, old_state, StateType.Wait, converted)
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 8425992..e20b84f 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,7 +1,7 @@
from . import lib, VPNStateChange, encode_args, decode_res
from typing import Optional, Tuple
import threading
-from .discovery import get_disco_organizations
+from .discovery import get_disco_organizations, get_disco_servers
from .event import EventHandler
from .state import State, StateType
from .server import get_servers
@@ -88,16 +88,20 @@ class EduVPN(object):
raise Exception(register_err)
def get_disco_servers(self) -> str:
- servers, servers_err = self.go_function(lib.GetDiscoServers)
+ servers = self.go_function_custom_decode(
+ lib.GetDiscoServers, decode_func=get_disco_servers
+ )
- if servers_err:
- raise Exception(servers_err)
+ # if servers_err:
+ # raise Exception(servers_err)
return servers
def get_disco_organizations(self) -> str:
- organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations)
- #if organizations_err:
+ organizations = self.go_function_custom_decode(
+ lib.GetDiscoOrganizations, decode_func=get_disco_organizations
+ )
+ # if organizations_err:
# raise Exception(organizations_err)
return organizations
@@ -251,4 +255,6 @@ class EduVPN(object):
return self.go_function(lib.GetSavedServersOLD)
def get_saved_servers_new(self) -> str:
- return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers)
+ return self.go_function_custom_decode(
+ lib.GetSavedServersNEW, decode_func=get_servers
+ )
diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py
index b765ede..dce5d51 100644
--- a/wrappers/python/src/server.py
+++ b/wrappers/python/src/server.py
@@ -1,5 +1,6 @@
-from . import lib, cServers, cServerLocations
-from ctypes import cast, POINTER
+from . import lib, cServer, cServers, cServerLocations, cServerProfiles
+from ctypes import cast, POINTER, c_char_p
+from datetime import datetime
class Profile:
@@ -9,61 +10,85 @@ class Profile:
self.default_gateway = default_gateway
def __str__(self):
- return f"Profile: {self.display_name}"
+ return self.display_name
+
+
+class Profiles:
+ def __init__(self, profiles, current):
+ self.profiles = profiles
+ self.current_index = current
+
+ @property
+ def current(self):
+ if self.current_index < len(self.profiles):
+ return self.profiles[self.current_index]
+ return None
class Server:
- def __init__(self, url, display_name, profiles, current_profile, expire_time):
+ def __init__(self, url, display_name, profiles=None, expire_time=0):
self.url = url
self.display_name = display_name
self.profiles = profiles
self.current_profile = None
- if current_profile < len(profiles):
- self.current_profile = profiles[current_profile]
- self.expire_time = expire_time
+ self.expire_time = datetime.fromtimestamp(expire_time)
def __str__(self):
- return f"Server: {self.url}, with current profile: {self.current_profile}"
+ return self.display_name
class InstituteServer(Server):
- def __init__(
- self, url, display_name, support_contact, profiles, current_profile, expire_time
- ):
- super().__init__(url, display_name, profiles, current_profile, expire_time)
+ def __init__(self, url, display_name, support_contact, profiles, expire_time):
+ super().__init__(url, display_name, profiles, expire_time)
self.support_contact = support_contact
- def __str__(self):
- return f"Institute Server: {self.display_name}"
-
class SecureInternetServer(Server):
def __init__(
self,
- url,
+ org_id,
display_name,
support_contact,
profiles,
- current_profile,
expire_time,
country_code,
):
- super().__init__(url, display_name, profiles, current_profile, expire_time)
+ super().__init__(org_id, display_name, profiles, expire_time)
+ self.org_id = org_id
self.support_contact = support_contact
self.country_code = country_code
- def __str__(self):
- return f"Secure Internet Server: {self.display_name} with country {self.country_code}"
-
def get_type_for_str(type_str: str):
- if type_str is "secure_internet":
+ if type_str == "secure_internet":
return SecureInternetServer
- if type_str is "custom_server":
+ if type_str == "custom_server":
return Server
return InstituteServer
+def get_profiles(ptr):
+ if not ptr:
+ return []
+ profiles = []
+ _profiles = ptr.contents
+ current_profile = _profiles.current
+ if not _profiles.profiles:
+ return []
+ for i in range(_profiles.total_profiles):
+ if not _profiles.profiles[i]:
+ continue
+ profile = _profiles.profiles[i].contents
+ profiles.append(
+ Profile(
+ profile.identifier.decode("utf-8"),
+ profile.display_name.decode("utf-8"),
+ profile.default_gateway == 1,
+ )
+ )
+ return Profiles(profiles, current_profile)
+
+
def get_server(ptr, _type=None):
if not ptr:
return None
@@ -79,31 +104,13 @@ def get_server(ptr, _type=None):
support_contact = []
for i in range(current_server.total_support_contact):
support_contact.append(current_server.support_contact[i].decode("utf-8"))
- profiles = []
- if not current_server.profiles:
- return None
-
- _profiles = current_server.profiles.contents
- current_profile = _profiles.current
- for i in range(_profiles.total_profiles):
- if not _profiles.profiles or not _profiles.profiles[i]:
- return None
- profile = _profiles.profiles[i].contents
- profiles.append(
- Profile(
- profile.identifier.decode("utf-8"),
- profile.display_name.decode("utf-8"),
- profile.default_gateway == 1,
- )
- )
-
+ profiles = get_profiles(current_server.profiles)
if _type is SecureInternetServer:
return SecureInternetServer(
identifier,
display_name,
support_contact,
profiles,
- current_profile,
current_server.expire_time,
current_server.country_code.decode("utf-8"),
)
@@ -113,12 +120,21 @@ def get_server(ptr, _type=None):
display_name,
support_contact,
profiles,
- current_profile,
current_server.expire_time,
)
- return Server(
- identifier, display_name, profiles, current_profile, current_server.expire_time
- )
+ return Server(identifier, display_name, profiles, current_server.expire_time)
+
+
+def get_transition_server(ptr):
+ server = get_server(cast(ptr, POINTER(cServer)))
+ lib.FreeServer(ptr)
+ return server
+
+
+def get_transition_profiles(ptr):
+ profiles = get_profiles(cast(ptr, POINTER(cServerProfiles)))
+ lib.FreeProfiles(ptr)
+ return profiles
def get_servers(ptr):
@@ -147,6 +163,7 @@ def get_servers(ptr):
return returned
return None
+
def get_locations(ptr):
if ptr:
locations = cast(ptr, POINTER(cServerLocations)).contents