diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-16 10:46:28 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-16 10:46:28 +0200 |
| commit | 4bf1273c3f446ac3195fb700ec41c7cae7d20ac9 (patch) | |
| tree | cec8d9e405b7d6786023ca9b921a6f0473d28a71 | |
| parent | 02db081c85e56e6472c2f39e6a623fa4cdf359c4 (diff) | |
Discovery: Expose c types
| -rw-r--r-- | exports/c/disco.h | 19 | ||||
| -rw-r--r-- | exports/c/servers.h | 1 | ||||
| -rw-r--r-- | exports/disco.go | 98 | ||||
| -rw-r--r-- | exports/exports.go | 21 | ||||
| -rw-r--r-- | exports/servers.go | 65 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 8 | ||||
| -rw-r--r-- | internal/server/common.go | 103 | ||||
| -rw-r--r-- | internal/types/server.go | 1 | ||||
| -rw-r--r-- | state.go | 41 | ||||
| -rw-r--r-- | wrappers/python/src/__init__.py | 73 | ||||
| -rw-r--r-- | wrappers/python/src/discovery.py | 85 | ||||
| -rw-r--r-- | wrappers/python/src/event.py | 26 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 20 | ||||
| -rw-r--r-- | wrappers/python/src/server.py | 109 |
14 files changed, 455 insertions, 215 deletions
diff --git a/exports/c/disco.h b/exports/c/disco.h index 41d59fa..8fa07a4 100644 --- a/exports/c/disco.h +++ b/exports/c/disco.h @@ -1,6 +1,25 @@ // for size_t #include <stddef.h> +typedef struct discoveryServer { + const char* authentication_url_template; + const char* base_url; + const char* country_code; + const char* display_name; + const char* keyword_list; + const char** public_key_list; + size_t total_public_keys; + const char* server_type; + const char** support_contact; + size_t total_support_contact; +} discoveryServer; + +typedef struct discoveryServers { + unsigned long long int version; + discoveryServer** servers; + size_t total_servers; +} discoveryServers; + typedef struct discoveryOrganization { const char* display_name; const char* org_id; diff --git a/exports/c/servers.h b/exports/c/servers.h index 39e52a2..1b6cca9 100644 --- a/exports/c/servers.h +++ b/exports/c/servers.h @@ -26,6 +26,7 @@ typedef struct serverLocations { typedef struct server { const char* identifier; const char* display_name; + const char* server_type; const char* country_code; const char** support_contact; size_t total_support_contact; diff --git a/exports/disco.go b/exports/disco.go index 9ee2af9..a08f8ba 100644 --- a/exports/disco.go +++ b/exports/disco.go @@ -33,9 +33,8 @@ func getCPtrDiscoOrganizations( organizations *types.DiscoveryOrganizations, ) (C.size_t, **C.discoveryOrganization) { totalOrganizations := C.size_t(len(organizations.List)) - var organizationsPtr **C.discoveryOrganization if totalOrganizations > 0 { - organizationsPtr = (**C.discoveryOrganization)( + organizationsPtr := (**C.discoveryOrganization)( C.malloc(totalOrganizations * C.size_t(unsafe.Sizeof(uintptr(0)))), ) cOrganizations := (*[1<<30 - 1]*C.discoveryOrganization)(unsafe.Pointer(organizationsPtr))[:totalOrganizations:totalOrganizations] @@ -45,8 +44,52 @@ func getCPtrDiscoOrganizations( cOrganizations[index] = cOrganization index += 1 } + return totalOrganizations, organizationsPtr } - return totalOrganizations, organizationsPtr + return 0, nil +} + +func getCPtrDiscoServer( + state *eduvpn.VPNState, + server *types.DiscoveryServer, +) *C.discoveryServer { + returnedStruct := (*C.discoveryServer)( + C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServer{}))), + ) + returnedStruct.authentication_url_template = C.CString(server.AuthenticationURLTemplate) + returnedStruct.base_url = C.CString(server.BaseURL) + returnedStruct.country_code = C.CString(server.CountryCode) + returnedStruct.display_name = C.CString(state.GetTranslated(server.DisplayName)) + returnedStruct.keyword_list = C.CString(state.GetTranslated(server.KeywordList)) + returnedStruct.total_public_keys, returnedStruct.public_key_list = getCPtrListStrings( + server.PublicKeyList, + ) + returnedStruct.server_type = C.CString(server.Type) + returnedStruct.total_support_contact, returnedStruct.support_contact = getCPtrListStrings( + server.SupportContact, + ) + return returnedStruct +} + +func getCPtrDiscoServers( + state *eduvpn.VPNState, + servers *types.DiscoveryServers, +) (C.size_t, **C.discoveryServer) { + totalServers := C.size_t(len(servers.List)) + if totalServers > 0 { + serversPtr := (**C.discoveryServer)( + C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0)))), + ) + cServers := (*[1<<30 - 1]*C.discoveryServer)(unsafe.Pointer(serversPtr)) + index := 0 + for _, server := range servers.List { + cServer := getCPtrDiscoServer(state, &server) + cServers[index] = cServer + index += 1 + } + return totalServers, serversPtr + } + return 0, nil } func freeDiscoOrganization(cOrganization *C.discoveryOrganization) { @@ -57,6 +100,30 @@ func freeDiscoOrganization(cOrganization *C.discoveryOrganization) { C.free(unsafe.Pointer(cOrganization)) } +func freeDiscoServer(cServer *C.discoveryServer) { + C.free(unsafe.Pointer(cServer.authentication_url_template)) + C.free(unsafe.Pointer(cServer.base_url)) + C.free(unsafe.Pointer(cServer.country_code)) + C.free(unsafe.Pointer(cServer.display_name)) + C.free(unsafe.Pointer(cServer.keyword_list)) + freeCListStrings(cServer.public_key_list, cServer.total_public_keys) + C.free(unsafe.Pointer(cServer.server_type)) + freeCListStrings(cServer.support_contact, cServer.total_support_contact) + C.free(unsafe.Pointer(cServer)) +} + +//export FreeDiscoServers +func FreeDiscoServers(cServers *C.discoveryServers) { + if cServers.total_servers > 0 { + servers := (*[1<<30 - 1]*C.discoveryServer)(unsafe.Pointer(cServers.servers))[:cServers.total_servers:cServers.total_servers] + for i := C.size_t(0); i < cServers.total_servers; i++ { + freeDiscoServer(servers[i]) + } + C.free(unsafe.Pointer(cServers.servers)) + } + C.free(unsafe.Pointer(cServers)) +} + //export FreeDiscoOrganizations func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) { if cOrganizations.total_organizations > 0 { @@ -69,6 +136,31 @@ func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) { C.free(unsafe.Pointer(cOrganizations)) } +//export GetDiscoServers +func GetDiscoServers(name *C.char) *C.discoveryServers { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + // TODO + if stateErr != nil { + panic(stateErr) + } + servers, serversErr := state.GetDiscoServers() + // TODO + if serversErr != nil { + panic(serversErr) + } + + returnedStruct := (*C.discoveryServers)( + C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServers{}))), + ) + returnedStruct.version = C.ulonglong(servers.Version) + returnedStruct.total_servers, returnedStruct.servers = getCPtrDiscoServers( + state, + servers, + ) + return returnedStruct +} + //export GetDiscoOrganizations func GetDiscoOrganizations(name *C.char) *C.discoveryOrganizations { nameStr := C.GoString(name) diff --git a/exports/exports.go b/exports/exports.go index 797a192..0144500 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -38,6 +38,16 @@ func GetStateData( } case eduvpn.STATE_ASK_LOCATION: return (unsafe.Pointer)(getTransitionSecureLocations(data)) + case eduvpn.STATE_ASK_PROFILE: + return (unsafe.Pointer)(getTransitionProfiles(data)) + case eduvpn.STATE_DISCONNECTED: + return (unsafe.Pointer)(getTransitionServer(state, data)) + case eduvpn.STATE_DISCONNECTING: + return (unsafe.Pointer)(getTransitionServer(state, data)) + case eduvpn.STATE_CONNECTING: + return (unsafe.Pointer)(getTransitionServer(state, data)) + case eduvpn.STATE_CONNECTED: + return (unsafe.Pointer)(getTransitionServer(state, data)) default: return nil } @@ -229,17 +239,6 @@ func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char, return getConfigJSON(config, configType), C.CString(ErrorToString(configErr)) } -//export GetDiscoServers -func GetDiscoServers(name *C.char) (*C.char, *C.char) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, C.CString(ErrorToString(stateErr)) - } - servers, serversErr := state.GetDiscoServers() - return C.CString(servers), C.CString(ErrorToString(serversErr)) -} - //export SetProfileID func SetProfileID(name *C.char, data *C.char) *C.char { nameStr := C.GoString(name) diff --git a/exports/servers.go b/exports/servers.go index f92c08e..d33ac1f 100644 --- a/exports/servers.go +++ b/exports/servers.go @@ -51,7 +51,7 @@ func getCPtrProfiles(serverProfiles *server.ServerProfileInfo) *C.serverProfiles profiles[index] = getCPtrProfile(&profile) index += 1 } - // TODO: DO CURRENT PROFILE + cProfiles.current = C.int(serverProfiles.GetCurrentProfileIndex()) cProfiles.profiles = (**C.serverProfile)(profilesPtr) } return cProfiles @@ -59,7 +59,8 @@ func getCPtrProfiles(serverProfiles *server.ServerProfileInfo) *C.serverProfiles // Free the profiles by looping through them if there are any // Also free the pointer itself -func freeCProfiles(profiles *C.serverProfiles) { +//export FreeProfiles +func FreeProfiles(profiles *C.serverProfiles) { // We should only free the profiles if we have them (which we should) if profiles.total_profiles > 0 { // Convert it to a go slice @@ -119,12 +120,21 @@ func freeCListStrings(allStrings **C.char, totalStrings C.size_t) { // Function for getting the server, // It gets the main state as a pointer as we need to convert some string maps to localized strings // It gets the base information for a server as well -func getServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server { +func getCPtrServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server { // Allocation using malloc and the size of the struct server := (*C.server)(C.malloc(C.size_t(unsafe.Sizeof(C.server{})))) // String allocation and translate the display name - server.identifier = C.CString(base.URL) + identifier := base.URL + countryCode := "" + if base.Type == "secure_internet" { + identifier = state.Servers.SecureInternetHomeServer.HomeOrganizationID + countryCode = state.Servers.SecureInternetHomeServer.CurrentLocation + } + + server.identifier = C.CString(identifier) server.display_name = C.CString(state.GetTranslated(base.DisplayName)) + server.country_code = C.CString(countryCode) + server.server_type = C.CString(base.Type) // Call the helper to get the list of support contacts server.total_support_contact, server.support_contact = getCPtrListStrings( base.SupportContact, @@ -133,22 +143,26 @@ func getServer(state *eduvpn.VPNState, base *eduvpn.VPNServerBase) *C.server { // No endtime is given if we get servers when it has been partially initialised if base.EndTime.IsZero() { server.expire_time = C.ulonglong(0) + } else { + // The expire time should be stored as an unsigned long long in unix itme + server.expire_time = C.ulonglong(base.EndTime.Unix()) } - // The expire time should be stored as an unsigned long long in unix itme - server.expire_time = C.ulonglong(base.EndTime.Unix()) return server } // Function for freeing a single server // Gets the pointer to C struct -func freeServer(info *C.server) { +//export FreeServer +func FreeServer(info *C.server) { // Free strings C.free(unsafe.Pointer(info.identifier)) C.free(unsafe.Pointer(info.display_name)) + C.free(unsafe.Pointer(info.country_code)) + C.free(unsafe.Pointer(info.server_type)) // Free arrays freeCListStrings(info.support_contact, info.total_support_contact) - freeCProfiles(info.profiles) + FreeProfiles(info.profiles) // Free the struct itself C.free(unsafe.Pointer(info)) @@ -166,10 +180,11 @@ func getCPtrServers( servers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(serversPtr))[:totalServers:totalServers] index := 0 for _, server := range serverMap { - cServer := getServer(state, &server.Base) + cServer := getCPtrServer(state, &server.Base) servers[index] = cServer index += 1 } + return totalServers, serversPtr } return C.size_t(0), nil } @@ -182,7 +197,7 @@ func FreeServers(cServers *C.servers) { if cServers.total_custom > 0 { customServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.custom_servers))[:cServers.total_custom:cServers.total_custom] for i := C.size_t(0); i < cServers.total_custom; i++ { - freeServer(customServers[i]) + FreeServer(customServers[i]) } C.free(unsafe.Pointer(cServers.custom_servers)) } @@ -191,14 +206,13 @@ func FreeServers(cServers *C.servers) { instituteServers := (*[1<<30 - 1]*C.server)(unsafe.Pointer(cServers.institute_servers))[:cServers.total_institute:cServers.total_institute] for i := C.size_t(0); i < cServers.total_institute; i++ { - freeServer(instituteServers[i]) + FreeServer(instituteServers[i]) } C.free(unsafe.Pointer(cServers.institute_servers)) } // Free the secure internet server if there is one if cServers.secure_internet_server != nil { - C.free(unsafe.Pointer(cServers.secure_internet_server.country_code)) - freeServer(cServers.secure_internet_server) + FreeServer(cServers.secure_internet_server) } // Free the structure itself C.free(unsafe.Pointer(cServers)) @@ -219,7 +233,7 @@ func getSavedServersWithOptions(state *eduvpn.VPNState, servers *server.Servers) secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase() if secureInternetBaseErr == nil && secureInternetBase != nil { // FIXME: log error? - secureServerPtr = getServer(state, secureInternetBase) + secureServerPtr = getCPtrServer(state, secureInternetBase) // Give a new identifier C.free(unsafe.Pointer(secureServerPtr.identifier)) secureServerPtr.identifier = C.CString(servers.SecureInternetHomeServer.HomeOrganizationID) @@ -259,11 +273,11 @@ func getTransitionDataServers(state *eduvpn.VPNState, data interface{}) *C.serve //export FreeSecureLocations func FreeSecureLocations(locations *C.serverLocations) { - freeCListStrings(locations.locations, locations.total_locations); + freeCListStrings(locations.locations, locations.total_locations) C.free(unsafe.Pointer(locations)) } -func getTransitionSecureLocations(data interface{}) (*C.serverLocations) { +func getTransitionSecureLocations(data interface{}) *C.serverLocations { if locations, ok := data.([]string); ok { returnedStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{})))) returnedStruct.total_locations, returnedStruct.locations = getCPtrListStrings(locations) @@ -271,3 +285,22 @@ func getTransitionSecureLocations(data interface{}) (*C.serverLocations) { } return nil } + +func getTransitionProfiles(data interface{}) *C.serverProfiles { + if profiles, ok := data.(*server.ServerProfileInfo); ok { + return getCPtrProfiles(profiles) + } + return nil +} + +func getTransitionServer(state *eduvpn.VPNState, data interface{}) *C.server { + if server, ok := data.(server.Server); ok { + base, baseErr := server.GetBase() + if baseErr != nil { + // TODO: LOG + return nil + } + return getCPtrServer(state, base) + } + return nil +} diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 21125cb..b3b438c 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -182,15 +182,15 @@ func (discovery *Discovery) GetOrganizationsList() (*types.DiscoveryOrganization } // Get the server list -func (discovery *Discovery) GetServersList() (string, error) { +func (discovery *Discovery) GetServersList() (*types.DiscoveryServers, error) { if !discovery.DetermineServersUpdate() { - return discovery.Servers.RawString, nil + return &discovery.Servers, nil } file := "server_list.json" body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) if bodyErr != nil { // Return previous with an error - return discovery.Servers.RawString, &types.WrappedErrorMessage{ + return &discovery.Servers, &types.WrappedErrorMessage{ Message: "failed getting servers in Discovery", Err: bodyErr, } @@ -198,7 +198,7 @@ func (discovery *Discovery) GetServersList() (string, error) { // Update servers timestamp discovery.Servers.RawString = body discovery.Servers.Timestamp = util.GetCurrentTime() - return discovery.Servers.RawString, nil + return &discovery.Servers, nil } type GetOrgByIDNotFoundError struct { diff --git a/internal/server/common.go b/internal/server/common.go index 64b8079..9c941cf 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -64,6 +64,18 @@ type ServerProfileInfo struct { } `json:"info"` } +func (info ServerProfileInfo) GetCurrentProfileIndex() int { + index := 0 + for _, profile := range info.Info.ProfileList { + if profile.ID == info.Current { + return index + } + index += 1 + } + // Default is 'first' profile + return 0 +} + type ServerEndpointList struct { API string `json:"api_endpoint"` Authorization string `json:"authorization_endpoint"` @@ -118,97 +130,6 @@ func (servers *Servers) GetCurrentServer() (Server, error) { return server, nil } -type ServersConfiguredScreen struct { - CustomServers []ServerInfoScreen `json:"custom_servers"` - InstituteAccessServers []ServerInfoScreen `json:"institute_access_servers"` - SecureInternetServer *ServerInfoScreen `json:"secure_internet_server"` -} - -type ServerInfoScreen struct { - Identifier string `json:"identifier"` - DisplayName map[string]string `json:"display_name"` - CountryCode string `json:"country_code,omitempty"` - SupportContact []string `json:"support_contact"` - Profiles ServerProfileInfo `json:"profiles"` - ExpireTime int64 `json:"expire_time"` - Type string `json:"server_type"` -} - -func getServerInfoScreen(base ServerBase) ServerInfoScreen { - serverInfoScreen := ServerInfoScreen{} - serverInfoScreen.Identifier = base.URL - serverInfoScreen.DisplayName = base.DisplayName - serverInfoScreen.SupportContact = base.SupportContact - serverInfoScreen.Profiles = base.Profiles - - // If we still have the default end time, return 0 - // Such that clients will still be able to parse it correctly - if base.EndTime.IsZero() { - serverInfoScreen.ExpireTime = 0 - } else { - serverInfoScreen.ExpireTime = base.EndTime.Unix() - } - serverInfoScreen.Type = base.Type - - return serverInfoScreen -} - -func (servers *Servers) GetServersConfigured() *ServersConfiguredScreen { - customServersInfo := []ServerInfoScreen{} - instituteServersInfo := []ServerInfoScreen{} - var secureInternetServerInfo *ServerInfoScreen = nil - - for _, server := range servers.CustomServers.Map { - serverInfoScreen := getServerInfoScreen(server.Base) - customServersInfo = append(customServersInfo, serverInfoScreen) - } - - for _, server := range servers.InstituteServers.Map { - serverInfoScreen := getServerInfoScreen(server.Base) - instituteServersInfo = append(instituteServersInfo, serverInfoScreen) - } - - secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.GetBase() - - if secureInternetBaseErr == nil && secureInternetBase != nil { - // FIXME: log error? - secureInternetServerInfoReturned := getServerInfoScreen(*secureInternetBase) - secureInternetServerInfo = &secureInternetServerInfoReturned - secureInternetServerInfo.Identifier = servers.SecureInternetHomeServer.HomeOrganizationID - secureInternetServerInfo.CountryCode = servers.SecureInternetHomeServer.CurrentLocation - } - - return &ServersConfiguredScreen{ - CustomServers: customServersInfo, - InstituteAccessServers: instituteServersInfo, - SecureInternetServer: secureInternetServerInfo, - } -} - -func (servers *Servers) GetCurrentServerInfo() (*ServerInfoScreen, error) { - errorMessage := "failed getting current server info" - - currentServer, currentServerErr := servers.GetCurrentServer() - if currentServerErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} - } - - base, baseErr := currentServer.GetBase() - - if baseErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} - } - - serverInfoScreen := getServerInfoScreen(*base) - - if servers.IsType == SecureInternetServerType { - serverInfoScreen.Identifier = servers.SecureInternetHomeServer.HomeOrganizationID - serverInfoScreen.CountryCode = servers.SecureInternetHomeServer.CurrentLocation - } - - return &serverInfoScreen, nil -} - func (servers *Servers) addInstituteAndCustom( discoServer *types.DiscoveryServer, isCustom bool, diff --git a/internal/types/server.go b/internal/types/server.go index 33a6e9c..48f94fb 100644 --- a/internal/types/server.go +++ b/internal/types/server.go @@ -62,6 +62,7 @@ type DiscoveryServer struct { BaseURL string `json:"base_url"` CountryCode string `json:"country_code"` DisplayName DiscoMapOrString `json:"display_name,omitempty"` + KeywordList DiscoMapOrString `json:"keyword_list"` PublicKeyList []string `json:"public_key_list"` Type string `json:"server_type"` SupportContact []string `json:"support_contact"` @@ -15,7 +15,6 @@ import ( ) type ( - ServerInfo = server.ServerInfoScreen VPNServerBase = server.ServerBase ) @@ -42,10 +41,6 @@ type VPNState struct { Debug bool `json:"-"` } -func (state *VPNState) GetSavedServers() *server.ServersConfiguredScreen { - return state.Servers.GetServersConfigured() -} - // Register initializes the state with the following parameters: // - name: the name of the client // - directory: the directory where the config files are stored. Absolute or relative @@ -95,7 +90,7 @@ func (state *VPNState) Register( _, currentServerErr := state.Servers.GetCurrentServer() // Only actually return the error if we have no disco servers and no current server - if discoServersErr != nil && discoServers == "" && currentServerErr != nil { + if discoServersErr != nil && discoServers.Version == 0 && currentServerErr != nil { state.Logger.Error( fmt.Sprintf( "No configured servers, discovery servers is empty and no servers with error: %s", @@ -263,7 +258,6 @@ func (state *VPNState) getConfig( // Signal the server display info state.FSM.GoTransitionWithData(STATE_DISCONNECTED, currentServer, false) - // Save the config state.Config.Save(&state) @@ -621,13 +615,13 @@ func (state *VPNState) GetDiscoOrganizations() (*types.DiscoveryOrganizations, e return orgs, nil } -func (state *VPNState) GetDiscoServers() (string, error) { +func (state *VPNState) GetDiscoServers() (*types.DiscoveryServers, error) { servers, serversErr := state.Discovery.GetServersList() if serversErr != nil { state.Logger.Warning( fmt.Sprintf("Failed getting discovery servers, Err: %s", GetErrorTraceback(serversErr)), ) - return "", &types.WrappedErrorMessage{ + return nil, &types.WrappedErrorMessage{ Message: "failed getting discovery servers list", Err: serversErr, } @@ -682,19 +676,6 @@ func (state *VPNState) SetSearchServer() error { return nil } -func (state *VPNState) getServerInfoData() *server.ServerInfoScreen { - info, infoErr := state.Servers.GetCurrentServerInfo() - if infoErr != nil { - state.Logger.Error( - fmt.Sprintf( - "Failed getting server info data with error: %s", - GetErrorTraceback(infoErr), - ), - ) - } - return info -} - func (state *VPNState) SetConnected() error { errorMessage := "failed to set connected" if state.InFSMState(STATE_CONNECTED) { @@ -734,6 +715,7 @@ func (state *VPNState) SetConnected() error { } func (state *VPNState) SetConnecting() error { + errorMessage := "failed to set connecting" if state.InFSMState(STATE_CONNECTING) { // already loading connection, show no error state.Logger.Warning("Already connecting") @@ -747,7 +729,7 @@ func (state *VPNState) SetConnecting() error { ), ) return &types.WrappedErrorMessage{ - Message: "failed to set connecting", + Message: errorMessage, Err: FSMWrongStateTransitionError{ Got: state.FSM.Current, Want: STATE_CONNECTING, @@ -755,7 +737,18 @@ func (state *VPNState) SetConnecting() error { } } - state.FSM.GoTransition(STATE_CONNECTING) + currentServer, currentServerErr := state.Servers.GetCurrentServer() + if currentServerErr != nil { + state.Logger.Warning( + fmt.Sprintf( + "Failed setting connecting, cannot get current server with error: %s", + GetErrorTraceback(currentServerErr), + ), + ) + return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr} + } + + state.FSM.GoTransitionWithData(STATE_CONNECTING, currentServer, false) return nil } diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index ece4b46..ca07143 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -5,6 +5,7 @@ import pathlib import platform from typing import Tuple, Optional import json +from typing import List _lib_prefixes = defaultdict( lambda: "lib", @@ -40,11 +41,10 @@ class ErrorLevel(Enum): ERR_OTHER = 0 ERR_INFO = 1 + class cServerLocations(Structure): - _fields_ = [ - ("locations", POINTER(c_char_p)), - ("total_locations", c_size_t) - ] + _fields_ = [("locations", POINTER(c_char_p)), ("total_locations", c_size_t)] + class cDiscoveryOrganization(Structure): _fields_ = [ @@ -54,6 +54,7 @@ class cDiscoveryOrganization(Structure): ("keyword_list", c_char_p), ] + class cDiscoveryOrganizations(Structure): _fields_ = [ ("version", c_ulonglong), @@ -61,6 +62,30 @@ class cDiscoveryOrganizations(Structure): ("total_organizations", c_size_t), ] + +class cDiscoveryServer(Structure): + _fields_ = [ + ("authentication_url_template", c_char_p), + ("base_url", c_char_p), + ("country_code", c_char_p), + ("display_name", c_char_p), + ("keyword_list", c_char_p), + ("public_key_list", POINTER(c_char_p)), + ("total_public_keys", c_size_t), + ("server_type", c_char_p), + ("support_contact", POINTER(c_char_p)), + ("total_support_contact", c_size_t), + ] + + +class cDiscoveryServers(Structure): + _fields_ = [ + ("version", c_ulonglong), + ("servers", POINTER(POINTER(cDiscoveryServer))), + ("total_servers", c_size_t), + ] + + class cServerProfile(Structure): _fields_ = [ ("identifier", c_char_p), @@ -68,6 +93,7 @@ class cServerProfile(Structure): ("default_gateway", c_int), ] + class cServerProfiles(Structure): _fields_ = [ ("current", c_int), @@ -75,10 +101,12 @@ class cServerProfiles(Structure): ("total_profiles", c_size_t), ] + class cServer(Structure): _fields_ = [ ("identifier", c_char_p), ("display_name", c_char_p), + ("server_type", c_char_p), ("country_code", c_char_p), ("support_contact", POINTER(c_char_p)), ("total_support_contact", c_size_t), @@ -86,6 +114,7 @@ class cServer(Structure): ("expire_time", c_ulonglong), ] + class cServers(Structure): _fields_ = [ ("custom_servers", POINTER(POINTER(cServer))), @@ -95,6 +124,7 @@ class cServers(Structure): ("secure_internet", POINTER(cServer)), ] + class DataError(Structure): _fields_ = [("data", c_void_p), ("error", c_void_p)] @@ -104,9 +134,17 @@ VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error -lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [c_char_p], c_void_p -lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [c_char_p, c_char_p], c_void_p -lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [c_char_p, c_char_p], c_void_p +lib.RemoveSecureInternet.argtypes, lib.RemoveSecureInternet.restype = [ + c_char_p +], c_void_p +lib.RemoveInstituteAccess.argtypes, lib.RemoveInstituteAccess.restype = [ + c_char_p, + c_char_p, +], c_void_p +lib.RemoveCustomServer.argtypes, lib.RemoveCustomServer.restype = [ + c_char_p, + c_char_p, +], c_void_p lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ c_char_p, c_char_p, @@ -132,7 +170,7 @@ lib.Register.argtypes, lib.Register.restype = [ lib.GetDiscoOrganizations.argtypes, lib.GetDiscoOrganizations.restype = [ c_char_p ], c_void_p -lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], DataError +lib.GetDiscoServers.argtypes, lib.GetDiscoServers.restype = [c_char_p], c_void_p lib.GoBack.argtypes, lib.GoBack.restype = [c_char_p], None lib.CancelOAuth.argtypes, lib.CancelOAuth.restype = [c_char_p], c_void_p lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p, c_char_p], c_void_p @@ -150,9 +188,14 @@ lib.SetDisconnected.argtypes, lib.SetDisconnected.restype = [c_char_p, c_int], c lib.SetSearchServer.argtypes, lib.SetSearchServer.restype = [c_char_p], c_void_p lib.ShouldRenewButton.argtypes, lib.ShouldRenewButton.restype = [], int lib.RenewSession.argtypes, lib.RenewSession.restype = [c_char_p], c_void_p +lib.FreeProfiles.argtypes, lib.FreeProfiles.restype = [c_void_p], None lib.FreeSecureLocations.argtypes, lib.FreeSecureLocations.restype = [c_void_p], None lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None -lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [c_void_p], None +lib.FreeDiscoOrganizations.argtypes, lib.FreeDiscoOrganizations.restype = [ + c_void_p +], None +lib.FreeDiscoServers.argtypes, lib.FreeDiscoServers.restype = [c_void_p], None +lib.FreeServer.argtypes, lib.FreeServer.restype = [c_void_p], None lib.FreeServers.argtypes, lib.FreeServers.restype = [c_void_p], None lib.InFSMState.argtypes, lib.InFSMState.restype = [c_void_p, c_int], int lib.GetSavedServers.argtypes, lib.GetSavedServers.restype = [c_char_p], c_void_p @@ -186,6 +229,17 @@ def get_ptr_string(ptr: c_void_p) -> str: return "" +def get_ptr_list_strings( + strings: POINTER(c_char_p), total_strings: c_size_t +) -> List[str]: + if strings: + strings_list = [] + for i in range(total_strings): + strings_list.append(strings[i].decode("utf-8")) + return strings_list + return [] + + def get_ptr_error(ptr: c_void_p) -> Optional[WrappedError]: error_string = get_ptr_string(ptr) @@ -224,6 +278,7 @@ def get_data_error(data_error: DataError) -> Tuple[str, str]: def get_bool(boolInt: c_int) -> bool: return boolInt == 1 + decode_map = { c_int: get_bool, c_void_p: get_error, diff --git a/wrappers/python/src/discovery.py b/wrappers/python/src/discovery.py index 80c08cf..a1b87ea 100644 --- a/wrappers/python/src/discovery.py +++ b/wrappers/python/src/discovery.py @@ -1,4 +1,4 @@ -from . import lib, cDiscoveryOrganizations +from . import lib, cDiscoveryOrganizations, cDiscoveryServers, get_ptr_list_strings from ctypes import cast, POINTER @@ -9,6 +9,9 @@ class DiscoOrganization: self.secure_internet_home = secure_internet_home self.keyword_list = keyword_list + def __str__(self): + return self.display_name + class DiscoOrganizations: def __init__(self, version, organizations): @@ -16,6 +19,37 @@ class DiscoOrganizations: self.organizations = organizations +class DiscoServer: + def __init__( + self, + authentication_url_template, + base_url, + country_code, + display_name, + keyword_list, + public_keys, + server_type, + support_contacts, + ): + self.authentication_url_template = authentication_url_template + self.base_url = base_url + self.country_code = country_code + self.display_name = display_name + self.keyword_list = keyword_list + self.public_keys = public_keys + self.server_type = server_type + self.support_contacts = support_contacts + + def __str__(self): + return self.display_name + + +class DiscoServers: + def __init__(self, version, servers): + self.version = version + self.servers = servers + + def get_disco_organization(ptr): if not ptr: return None @@ -28,6 +62,55 @@ def get_disco_organization(ptr): return DiscoOrganization(display_name, org_id, secure_internet_home, keyword_list) +def get_disco_server(ptr): + if not ptr: + return None + + current_server = ptr.contents + authentication_url_template = current_server.authentication_url_template.decode( + "utf-8" + ) + base_url = current_server.base_url.decode("utf-8") + country_code = current_server.country_code.decode("utf-8") + display_name = current_server.display_name.decode("utf-8") + keyword_list = current_server.keyword_list.decode("utf-8") + public_keys = get_ptr_list_strings( + current_server.public_key_list, current_server.total_public_keys + ) + server_type = current_server.server_type.decode("utf-8") + support_contacts = get_ptr_list_strings( + current_server.support_contact, current_server.total_support_contact + ) + return DiscoServer( + authentication_url_template, + base_url, + country_code, + display_name, + keyword_list, + public_keys, + server_type, + support_contacts, + ) + + +def get_disco_servers(ptr): + if ptr: + svrs = cast(ptr, POINTER(cDiscoveryServers)).contents + + servers = [] + + if svrs.servers: + for i in range(svrs.total_servers): + current = get_disco_server(svrs.servers[i]) + + if current is None: + continue + servers.append(current) + lib.FreeDiscoServers(ptr) + return DiscoServers(svrs.version, servers) + return None + + def get_disco_organizations(ptr): if ptr: orgs = cast(ptr, POINTER(cDiscoveryOrganizations)).contents diff --git a/wrappers/python/src/event.py b/wrappers/python/src/event.py index 0e0f5ae..cf1a9d0 100644 --- a/wrappers/python/src/event.py +++ b/wrappers/python/src/event.py @@ -2,7 +2,12 @@ from . import VPNStateChange, get_ptr_string from enum import Enum from typing import Callable from .state import State, StateType -from .server import get_locations, get_servers +from .server import ( + get_locations, + get_transition_profiles, + get_transition_server, + get_servers, +) EDUVPN_CALLBACK_PROPERTY = "_eduvpn_property_callback" @@ -16,6 +21,7 @@ def class_state_transition(state: int, state_type: StateType) -> Callable: return wrapper + def convert_data(state: State, data): if not data: return None @@ -25,6 +31,16 @@ def convert_data(state: State, data): return get_ptr_string(data) if state is State.ASK_LOCATION: return get_locations(data) + if state is State.ASK_PROFILE: + return get_transition_profiles(data) + if state in [ + State.DISCONNECTED, + State.DISCONNECTING, + State.CONNECTING, + State.CONNECTED, + ]: + return get_transition_server(data) + class EventHandler(object): def __init__(self): @@ -80,10 +96,14 @@ class EventHandler(object): for func in self.handlers[(state, state_type)]: func(other_state, data) - def run(self, old_state: int, new_state: int, data: str) -> None: + def run( + self, old_state: int, new_state: int, data: str, convert: bool = True + ) -> None: # First run leave transitions, then enter # The state is done when the wait event finishes - converted = convert_data(new_state, data) + converted = data + if convert: + converted = convert_data(new_state, data) self.run_state(old_state, new_state, StateType.Leave, converted) self.run_state(new_state, old_state, StateType.Enter, converted) self.run_state(new_state, old_state, StateType.Wait, converted) diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 8425992..e20b84f 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,7 +1,7 @@ from . import lib, VPNStateChange, encode_args, decode_res from typing import Optional, Tuple import threading -from .discovery import get_disco_organizations +from .discovery import get_disco_organizations, get_disco_servers from .event import EventHandler from .state import State, StateType from .server import get_servers @@ -88,16 +88,20 @@ class EduVPN(object): raise Exception(register_err) def get_disco_servers(self) -> str: - servers, servers_err = self.go_function(lib.GetDiscoServers) + servers = self.go_function_custom_decode( + lib.GetDiscoServers, decode_func=get_disco_servers + ) - if servers_err: - raise Exception(servers_err) + # if servers_err: + # raise Exception(servers_err) return servers def get_disco_organizations(self) -> str: - organizations = self.go_function_custom_decode(lib.GetDiscoOrganizations, decode_func=get_disco_organizations) - #if organizations_err: + organizations = self.go_function_custom_decode( + lib.GetDiscoOrganizations, decode_func=get_disco_organizations + ) + # if organizations_err: # raise Exception(organizations_err) return organizations @@ -251,4 +255,6 @@ class EduVPN(object): return self.go_function(lib.GetSavedServersOLD) def get_saved_servers_new(self) -> str: - return self.go_function_custom_decode(lib.GetSavedServersNEW, decode_func=get_servers) + return self.go_function_custom_decode( + lib.GetSavedServersNEW, decode_func=get_servers + ) diff --git a/wrappers/python/src/server.py b/wrappers/python/src/server.py index b765ede..dce5d51 100644 --- a/wrappers/python/src/server.py +++ b/wrappers/python/src/server.py @@ -1,5 +1,6 @@ -from . import lib, cServers, cServerLocations -from ctypes import cast, POINTER +from . import lib, cServer, cServers, cServerLocations, cServerProfiles +from ctypes import cast, POINTER, c_char_p +from datetime import datetime class Profile: @@ -9,61 +10,85 @@ class Profile: self.default_gateway = default_gateway def __str__(self): - return f"Profile: {self.display_name}" + return self.display_name + + +class Profiles: + def __init__(self, profiles, current): + self.profiles = profiles + self.current_index = current + + @property + def current(self): + if self.current_index < len(self.profiles): + return self.profiles[self.current_index] + return None class Server: - def __init__(self, url, display_name, profiles, current_profile, expire_time): + def __init__(self, url, display_name, profiles=None, expire_time=0): self.url = url self.display_name = display_name self.profiles = profiles self.current_profile = None - if current_profile < len(profiles): - self.current_profile = profiles[current_profile] - self.expire_time = expire_time + self.expire_time = datetime.fromtimestamp(expire_time) def __str__(self): - return f"Server: {self.url}, with current profile: {self.current_profile}" + return self.display_name class InstituteServer(Server): - def __init__( - self, url, display_name, support_contact, profiles, current_profile, expire_time - ): - super().__init__(url, display_name, profiles, current_profile, expire_time) + def __init__(self, url, display_name, support_contact, profiles, expire_time): + super().__init__(url, display_name, profiles, expire_time) self.support_contact = support_contact - def __str__(self): - return f"Institute Server: {self.display_name}" - class SecureInternetServer(Server): def __init__( self, - url, + org_id, display_name, support_contact, profiles, - current_profile, expire_time, country_code, ): - super().__init__(url, display_name, profiles, current_profile, expire_time) + super().__init__(org_id, display_name, profiles, expire_time) + self.org_id = org_id self.support_contact = support_contact self.country_code = country_code - def __str__(self): - return f"Secure Internet Server: {self.display_name} with country {self.country_code}" - def get_type_for_str(type_str: str): - if type_str is "secure_internet": + if type_str == "secure_internet": return SecureInternetServer - if type_str is "custom_server": + if type_str == "custom_server": return Server return InstituteServer +def get_profiles(ptr): + if not ptr: + return [] + profiles = [] + _profiles = ptr.contents + current_profile = _profiles.current + if not _profiles.profiles: + return [] + for i in range(_profiles.total_profiles): + if not _profiles.profiles[i]: + continue + profile = _profiles.profiles[i].contents + profiles.append( + Profile( + profile.identifier.decode("utf-8"), + profile.display_name.decode("utf-8"), + profile.default_gateway == 1, + ) + ) + return Profiles(profiles, current_profile) + + def get_server(ptr, _type=None): if not ptr: return None @@ -79,31 +104,13 @@ def get_server(ptr, _type=None): support_contact = [] for i in range(current_server.total_support_contact): support_contact.append(current_server.support_contact[i].decode("utf-8")) - profiles = [] - if not current_server.profiles: - return None - - _profiles = current_server.profiles.contents - current_profile = _profiles.current - for i in range(_profiles.total_profiles): - if not _profiles.profiles or not _profiles.profiles[i]: - return None - profile = _profiles.profiles[i].contents - profiles.append( - Profile( - profile.identifier.decode("utf-8"), - profile.display_name.decode("utf-8"), - profile.default_gateway == 1, - ) - ) - + profiles = get_profiles(current_server.profiles) if _type is SecureInternetServer: return SecureInternetServer( identifier, display_name, support_contact, profiles, - current_profile, current_server.expire_time, current_server.country_code.decode("utf-8"), ) @@ -113,12 +120,21 @@ def get_server(ptr, _type=None): display_name, support_contact, profiles, - current_profile, current_server.expire_time, ) - return Server( - identifier, display_name, profiles, current_profile, current_server.expire_time - ) + return Server(identifier, display_name, profiles, current_server.expire_time) + + +def get_transition_server(ptr): + server = get_server(cast(ptr, POINTER(cServer))) + lib.FreeServer(ptr) + return server + + +def get_transition_profiles(ptr): + profiles = get_profiles(cast(ptr, POINTER(cServerProfiles))) + lib.FreeProfiles(ptr) + return profiles def get_servers(ptr): @@ -147,6 +163,7 @@ def get_servers(ptr): return returned return None + def get_locations(ptr): if ptr: locations = cast(ptr, POINTER(cServerLocations)).contents |
