From 7e7472c30e09eed15424494547729c1f93bc924e Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 20 Mar 2023 13:02:18 +0100 Subject: Exports: Implement initial V2 API The main change is that we now use JSON from types listed at the `types` package --- exports/disco.go | 217 ------------------ exports/exports.go | 643 ++++++++++++++++++++--------------------------------- exports/servers.go | 342 ---------------------------- 3 files changed, 236 insertions(+), 966 deletions(-) delete mode 100644 exports/disco.go delete mode 100644 exports/servers.go (limited to 'exports') diff --git a/exports/disco.go b/exports/disco.go deleted file mode 100644 index 08c921f..0000000 --- a/exports/disco.go +++ /dev/null @@ -1,217 +0,0 @@ -package main - -/* -// for free and size_t -#include -#include "error.h" - -typedef struct discoveryServer { - const char* authentication_url_template; - const char* base_url; - const char* country_code; - const char* display_name; - const char* keyword_list; - const char** public_key_list; - size_t total_public_keys; - const char* server_type; - const char** support_contact; - size_t total_support_contact; -} discoveryServer; - -typedef struct discoveryServers { - unsigned long long int version; - discoveryServer** servers; - size_t total_servers; -} discoveryServers; - -typedef struct discoveryOrganization { - const char* display_name; - const char* org_id; - const char* secure_internet_home; - const char* keyword_list; -} discoveryOrganization; - -typedef struct discoveryOrganizations { - unsigned long long int version; - discoveryOrganization** organizations; - size_t total_organizations; -} discoveryOrganizations; -*/ -import "C" - -import ( - "unsafe" - - "github.com/eduvpn/eduvpn-common/client" - "github.com/eduvpn/eduvpn-common/types" -) - -func getCPtrDiscoOrganization( - state *client.Client, - organization *types.DiscoveryOrganization, -) *C.discoveryOrganization { - returnedStruct := (*C.discoveryOrganization)( - C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganization{}))), - ) - returnedStruct.display_name = C.CString(state.GetTranslated(organization.DisplayName)) - returnedStruct.org_id = C.CString(organization.OrgID) - returnedStruct.secure_internet_home = C.CString(organization.SecureInternetHome) - returnedStruct.keyword_list = C.CString(state.GetTranslated(organization.KeywordList)) - return returnedStruct -} - -func getCPtrDiscoOrganizations( - state *client.Client, - organizations *types.DiscoveryOrganizations, -) (C.size_t, **C.discoveryOrganization) { - totalOrganizations := C.size_t(len(organizations.List)) - if totalOrganizations > 0 { - organizationsPtr := (**C.discoveryOrganization)( - C.malloc(totalOrganizations * C.size_t(unsafe.Sizeof(uintptr(0)))), - ) - cOrganizations := unsafe.Slice(organizationsPtr, totalOrganizations) - index := 0 - for _, organization := range organizations.List { - cOrganization := getCPtrDiscoOrganization(state, &organization) - cOrganizations[index] = cOrganization - index++ - } - return totalOrganizations, organizationsPtr - } - return 0, nil -} - -func getCPtrDiscoServer( - state *client.Client, - server *types.DiscoveryServer, -) *C.discoveryServer { - returnedStruct := (*C.discoveryServer)( - C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServer{}))), - ) - returnedStruct.authentication_url_template = C.CString(server.AuthenticationURLTemplate) - returnedStruct.base_url = C.CString(server.BaseURL) - returnedStruct.country_code = C.CString(server.CountryCode) - returnedStruct.display_name = C.CString(state.GetTranslated(server.DisplayName)) - returnedStruct.keyword_list = C.CString(state.GetTranslated(server.KeywordList)) - returnedStruct.total_public_keys, returnedStruct.public_key_list = getCPtrListStrings( - server.PublicKeyList, - ) - returnedStruct.server_type = C.CString(server.Type) - returnedStruct.total_support_contact, returnedStruct.support_contact = getCPtrListStrings( - server.SupportContact, - ) - return returnedStruct -} - -func getCPtrDiscoServers( - state *client.Client, - servers *types.DiscoveryServers, -) (C.size_t, **C.discoveryServer) { - totalServers := C.size_t(len(servers.List)) - if totalServers > 0 { - serversPtr := (**C.discoveryServer)( - C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0)))), - ) - cServers := unsafe.Slice(serversPtr, totalServers) - index := 0 - for _, server := range servers.List { - cServer := getCPtrDiscoServer(state, &server) - cServers[index] = cServer - index++ - } - return totalServers, serversPtr - } - return 0, nil -} - -func freeDiscoOrganization(cOrganization *C.discoveryOrganization) { - C.free(unsafe.Pointer(cOrganization.display_name)) - C.free(unsafe.Pointer(cOrganization.org_id)) - C.free(unsafe.Pointer(cOrganization.secure_internet_home)) - C.free(unsafe.Pointer(cOrganization.keyword_list)) - C.free(unsafe.Pointer(cOrganization)) -} - -func freeDiscoServer(cServer *C.discoveryServer) { - C.free(unsafe.Pointer(cServer.authentication_url_template)) - C.free(unsafe.Pointer(cServer.base_url)) - C.free(unsafe.Pointer(cServer.country_code)) - C.free(unsafe.Pointer(cServer.display_name)) - C.free(unsafe.Pointer(cServer.keyword_list)) - freeCListStrings(cServer.public_key_list, cServer.total_public_keys) - C.free(unsafe.Pointer(cServer.server_type)) - freeCListStrings(cServer.support_contact, cServer.total_support_contact) - C.free(unsafe.Pointer(cServer)) -} - -//export FreeDiscoServers -func FreeDiscoServers(cServers *C.discoveryServers) { - if cServers.total_servers > 0 { - servers := unsafe.Slice(cServers.servers, cServers.total_servers) - for i := C.size_t(0); i < cServers.total_servers; i++ { - freeDiscoServer(servers[i]) - } - C.free(unsafe.Pointer(cServers.servers)) - } - C.free(unsafe.Pointer(cServers)) -} - -//export FreeDiscoOrganizations -func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) { - if cOrganizations.total_organizations > 0 { - organizations := unsafe.Slice(cOrganizations.organizations, cOrganizations.total_organizations) - for i := C.size_t(0); i < cOrganizations.total_organizations; i++ { - freeDiscoOrganization(organizations[i]) - } - C.free(unsafe.Pointer(cOrganizations.organizations)) - } - C.free(unsafe.Pointer(cOrganizations)) -} - -//export GetDiscoServers -func GetDiscoServers(name *C.char) (*C.discoveryServers, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, getError(stateErr) - } - servers, serversErr := state.DiscoServers() - // if we get no servers then we return immediately - if servers == nil { - return nil, getError(serversErr) - } - returnedStruct := (*C.discoveryServers)( - C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServers{}))), - ) - returnedStruct.version = C.ulonglong(servers.Version) - returnedStruct.total_servers, returnedStruct.servers = getCPtrDiscoServers( - state, - servers, - ) - return returnedStruct, getError(serversErr) -} - -//export GetDiscoOrganizations -func GetDiscoOrganizations(name *C.char) (*C.discoveryOrganizations, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, getError(stateErr) - } - organizations, organizationsErr := state.DiscoOrganizations() - // if we get no organizations then we return immediately - if organizations == nil { - return nil, getError(organizationsErr) - } - returnedStruct := (*C.discoveryOrganizations)( - C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganizations{}))), - ) - - returnedStruct.version = C.ulonglong(organizations.Version) - returnedStruct.total_organizations, returnedStruct.organizations = getCPtrDiscoOrganizations( - state, - organizations, - ) - - return returnedStruct, getError(organizationsErr) -} diff --git a/exports/exports.go b/exports/exports.go index c09e980..d8fd6ea 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -6,114 +6,87 @@ package main #include "server.h" typedef long long int (*ReadRxBytes)(); -typedef struct token { - const char* access; - const char* refresh; - unsigned long long int expired; -} token; -typedef void (*UpdateToken)(const char* name, server* srv, token* tok); - -typedef struct configData { - const char* config; - const char* config_type; - token* tokens; -} configData; - -typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data); +typedef int (*StateCB)(int oldstate, int newstate, void* data); static long long int get_read_rx_bytes(ReadRxBytes read) { return read(); } - -static void update_token(UpdateToken func, const char* name, server* srv, token* tok) -{ - func(name, srv, tok); -} - -static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) +static int call_callback(StateCB callback, int oldstate, int newstate, void* data) { - return callback(name, oldstate, newstate, data); + return callback(oldstate, newstate, data); } */ import "C" import ( - "time" + "encoding/json" "unsafe" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/server" "github.com/go-errors/errors" "github.com/eduvpn/eduvpn-common/client" + "github.com/eduvpn/eduvpn-common/types" ) -var PStateCallbacks map[string]C.PythonCB +var ( + PStateCallback C.StateCB + VPNState *client.Client +) -var VPNStates map[string]*client.Client +func getTokens(tokens *C.char) (t types.Tokens, err error) { + err = json.Unmarshal([]byte(C.GoString(tokens)), &t) + return t, err +} -func GetStateData( - state *client.Client, - stateID client.FSMStateID, - data interface{}, -) unsafe.Pointer { - switch stateID { - case client.StateNoServer: - return (unsafe.Pointer)(getTransitionDataServers(state, data)) - case client.StateOAuthStarted: - if converted, ok := data.(string); ok { - return (unsafe.Pointer)(C.CString(converted)) - } - case client.StateAskLocation: - return (unsafe.Pointer)(getTransitionSecureLocations(data)) - case client.StateAskProfile: - return (unsafe.Pointer)(getTransitionProfiles(data)) - case client.StateDisconnected: - return (unsafe.Pointer)(getTransitionServer(state, data)) - case client.StateDisconnecting: - return (unsafe.Pointer)(getTransitionServer(state, data)) - case client.StateConnecting: - return (unsafe.Pointer)(getTransitionServer(state, data)) - case client.StateConnected: - return (unsafe.Pointer)(getTransitionServer(state, data)) - default: - return nil +func getCError(err error) *C.char { + if err == nil { + return C.CString("") } - return nil + return C.CString(err.Error()) +} + +func getReturnData(data interface{}) (string, error) { + // If it is already a string return directly + if x, ok := data.(string); ok { + return x, nil + } + + // Otherwise use JSON + b, err := json.Marshal(data) + if err != nil { + return "", err + } + return string(b), nil } func StateCallback( state *client.Client, - name string, oldState client.FSMStateID, newState client.FSMStateID, data interface{}, ) bool { - PStateCallback, exists := PStateCallbacks[name] - if !exists || PStateCallback == nil { + if PStateCallback == nil { return false } - nameC := C.CString(name) oldStateC := C.int(oldState) newStateC := C.int(newState) - dataC := GetStateData(state, newState, data) - handled := C.call_callback(PStateCallback, nameC, oldStateC, newStateC, dataC) - C.free(unsafe.Pointer(nameC)) - // data_c gets freed by the wrapper + d, err := getReturnData(data) + if err != nil { + return false + } + dataC := C.CString(d) + handled := C.call_callback(PStateCallback, oldStateC, newStateC, unsafe.Pointer(dataC)) + FreeString(dataC) return handled == C.int(1) } -func GetVPNState(name string) (*client.Client, error) { - state, exists := VPNStates[name] - - if !exists || state == nil { - return nil, errors.Errorf("state with name %s not found", name) +func getVPNState() (*client.Client, error) { + if VPNState == nil { + return nil, errors.New("No state available, did you register the client?") } - - return state, nil + return VPNState, nil } //export Register @@ -121,39 +94,44 @@ func Register( name *C.char, version *C.char, configDirectory *C.char, - language *C.char, - stateCallback C.PythonCB, + stateCallback C.StateCB, debug C.int, -) *C.error { - nameStr := C.GoString(name) - versionStr := C.GoString(version) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - state = &client.Client{} - } - if VPNStates == nil { - VPNStates = make(map[string]*client.Client) - } - if PStateCallbacks == nil { - PStateCallbacks = make(map[string]C.PythonCB) - } - VPNStates[nameStr] = state - PStateCallbacks[nameStr] = stateCallback +) *C.char { + state, stateErr := getVPNState() + if stateErr == nil { + return getCError(errors.New("failed to register, a VPN state is already present")) + } + state = &client.Client{} + VPNState = state + PStateCallback = stateCallback registerErr := state.Register( - nameStr, - versionStr, + C.GoString(name), + C.GoString(version), C.GoString(configDirectory), - C.GoString(language), func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool { - return StateCallback(state, nameStr, old, new, data) + return StateCallback(state, old, new, data) }, - debug != 0, + debug == 1, ) - if registerErr != nil { - delete(VPNStates, nameStr) + return getCError(registerErr) +} + +//export ExpiryTimes +func ExpiryTimes() (*C.char, *C.char) { + state, stateErr := getVPNState() + if stateErr != nil { + return nil, getCError(stateErr) + } + exp, err := state.ExpiryTimes() + if err != nil { + return nil, getCError(err) + } + ret, err := getReturnData(exp) + if err != nil { + return nil, getCError(err) } - return getError(registerErr) + return C.CString(ret), nil } //export SetTokenUpdater @@ -179,362 +157,216 @@ func SetTokenUpdater(name *C.char, updater C.UpdateToken) *C.error { } //export Deregister -func Deregister(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func Deregister() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } state.Deregister() + VPNState = nil return nil } -func getError(err error) *C.error { - if err == nil { - return nil - } - errorStruct := (*C.error)( - C.malloc(C.size_t(unsafe.Sizeof(C.error{}))), - ) - if err1, ok := err.(*errors.Error); ok { - if err1 == nil { - errorStruct.traceback = C.CString("N/A") - errorStruct.cause = C.CString("unknown error") - return errorStruct - } - errorStruct.traceback = C.CString(err1.ErrorStack()) - if err1.Err == nil { - errorStruct.cause = C.CString(err1.Error()) - } else { - errorStruct.cause = C.CString(err1.Err.Error()) - } - } else { - errorStruct.traceback = C.CString("N/A") - errorStruct.cause = C.CString(err.Error()) - } - return errorStruct -} - -//export FreeError -func FreeError(err *C.error) { - C.free(unsafe.Pointer(err.traceback)) - C.free(unsafe.Pointer(err.cause)) - C.free(unsafe.Pointer(err)) -} - //export CancelOAuth -func CancelOAuth(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func CancelOAuth() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } cancelErr := state.CancelOAuth() - return getError(cancelErr) + return getCError(cancelErr) } -//export RemoveSecureInternet -func RemoveSecureInternet(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export AddServer +func AddServer(_type *C.char, id *C.char) *C.char { + // TODO: type + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) - } - removeErr := state.RemoveSecureInternet() - return getError(removeErr) -} - -//export AddInstituteAccess -func AddInstituteAccess(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) + } + t := C.GoString(_type) + var err error + switch t { + case "institute_access": + err = state.AddInstituteServer(C.GoString(id)) + case "secure_internet": + err = state.AddSecureInternetHomeServer(C.GoString(id)) + case "custom_server": + err = state.AddCustomServer(C.GoString(id)) + default: + err = errors.Errorf("invalid type: %v", t) } - // FIXME: Return server result - _, addErr := state.AddInstituteServer(C.GoString(url)) - return getError(addErr) + return getCError(err) } -//export AddSecureInternetHomeServer -func AddSecureInternetHomeServer(name *C.char, orgID *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export RemoveServer +func RemoveServer(_type *C.char, id *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) - } - // FIXME: Return server result - _, addErr := state.AddSecureInternetHomeServer(C.GoString(orgID)) - return getError(addErr) -} - -//export AddCustomServer -func AddCustomServer(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) + } + t := C.GoString(_type) + var err error + switch t { + case "institute_access": + err = state.RemoveInstituteAccess(C.GoString(id)) + case "secure_internet": + err = state.RemoveSecureInternet() + case "custom_server": + err = state.RemoveCustomServer(C.GoString(id)) + default: + err = errors.Errorf("invalid type: %v", t) } - // FIXME: Return server result - _, addErr := state.AddCustomServer(C.GoString(url)) - return getError(addErr) + return getCError(err) } -//export RemoveInstituteAccess -func RemoveInstituteAccess(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export CurrentServer +func CurrentServer() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - removeErr := state.RemoveInstituteAccess(C.GoString(url)) - return getError(removeErr) -} - -//export RemoveCustomServer -func RemoveCustomServer(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + srv, err := state.CurrentServer() + if err != nil { + return nil, getCError(err) } - removeErr := state.RemoveCustomServer(C.GoString(url)) - return getError(removeErr) -} - -func cToken(t oauth.Token) *C.token { - cTok := (*C.token)(C.malloc(C.size_t(unsafe.Sizeof(C.token{})))) - cTok.access = C.CString(t.Access) - cTok.refresh = C.CString(t.Refresh) - cTok.expired = C.ulonglong(t.ExpiredTimestamp.Unix()) - return cTok -} - -func cConfig(config *client.ConfigData) *C.configData { - // No config so return nil pointer - if config == nil { - return nil + ret, err := getReturnData(srv) + if err != nil { + return nil, getCError(err) } - cConf := (*C.configData)(C.malloc(C.size_t(unsafe.Sizeof(C.configData{})))) - cConf.config = C.CString(config.Config) - cConf.config_type = C.CString(config.Type) - cConf.tokens = cToken(config.Tokens) - return cConf + return C.CString(ret), nil } -//export FreeTokens -func FreeTokens(tokens *C.token) { - C.free(unsafe.Pointer(tokens.access)) - C.free(unsafe.Pointer(tokens.refresh)) - C.free(unsafe.Pointer(tokens)) -} - -//export FreeConfig -func FreeConfig(config *C.configData) { - C.free(unsafe.Pointer(config.config)) - C.free(unsafe.Pointer(config.config_type)) - C.free(unsafe.Pointer(config.tokens.access)) - C.free(unsafe.Pointer(config.tokens.refresh)) - C.free(unsafe.Pointer(config.tokens)) - C.free(unsafe.Pointer(config)) -} - -//export GetConfigSecureInternet -func GetConfigSecureInternet( - name *C.char, - orgID *C.char, - preferTCP C.int, - prevTokens C.token, -) (*C.configData, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export ServerList +func ServerList() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return nil, getError(stateErr) - } - preferTCPBool := preferTCP == 1 - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + return nil, getCError(stateErr) } - cfg, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool, t) - return cConfig(cfg), getError(configErr) -} - -//export GetConfigInstituteAccess -func GetConfigInstituteAccess( - name *C.char, - url *C.char, - preferTCP C.int, - prevTokens C.token, -) (*C.configData, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, getError(stateErr) + list, err := state.ServerList() + if err != nil { + return nil, getCError(err) } - preferTCPBool := preferTCP == 1 - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + ret, err := getReturnData(list) + if err != nil { + return nil, getCError(err) } - cfg, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool, t) - return cConfig(cfg), getError(configErr) + return C.CString(ret), nil } -//export GetConfigCustomServer -func GetConfigCustomServer( - name *C.char, - url *C.char, - preferTCP C.int, - prevTokens C.token, -) (*C.configData, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export GetConfig +func GetConfig(_type *C.char, id *C.char, pTCP C.int, tokens *C.char) (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return nil, getError(stateErr) + return nil, getCError(stateErr) + } + preferTCPBool := pTCP == 1 + tok, err := getTokens(tokens) + if err != nil { + return nil, getCError(err) + } + t := C.GoString(_type) + var cfg *types.Configuration + switch t { + case "institute_access": + cfg, err = state.GetConfigInstituteAccess(C.GoString(id), preferTCPBool, tok) + case "secure_internet": + cfg, err = state.GetConfigSecureInternet(C.GoString(id), preferTCPBool, tok) + case "custom_server": + cfg, err = state.GetConfigCustomServer(C.GoString(id), preferTCPBool, tok) + default: + err = errors.Errorf("invalid type: %v", t) } - preferTCPBool := preferTCP == 1 - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + if cfg != nil && err == nil { + d, err := getReturnData(cfg) + if err == nil { + return C.CString(d), nil + } } - cfg, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool, t) - return cConfig(cfg), getError(configErr) + return nil, getCError(err) } //export SetProfileID -func SetProfileID(name *C.char, data *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func SetProfileID(data *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } profileErr := state.SetProfileID(C.GoString(data)) - return getError(profileErr) -} - -//export ChangeSecureLocation -func ChangeSecureLocation(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) - } - locationErr := state.ChangeSecureLocation() - return getError(locationErr) + return getCError(profileErr) } //export SetSecureLocation -func SetSecureLocation(name *C.char, data *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func SetSecureLocation(data *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } locationErr := state.SetSecureLocation(C.GoString(data)) - return getError(locationErr) -} - -//export GoBack -func GoBack(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) - } - goBackErr := state.GoBack() - return getError(goBackErr) + return getCError(locationErr) } -//export SetSearchServer -func SetSearchServer(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export DiscoServers +func DiscoServers() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - setSearchErr := state.SetSearchServer() - return getError(setSearchErr) -} - -//export Cleanup -func Cleanup(name *C.char, prevTokens C.token) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + servers, err := state.DiscoServers() + if servers == nil && err != nil { + return nil, getCError(err) } - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + s, reterr := getReturnData(servers) + if reterr != nil { + return nil, getCError(reterr) } - err := state.Cleanup(t) - return getError(err) + return C.CString(s), getCError(err) } -//export SetDisconnected -func SetDisconnected(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export DiscoOrganizations +func DiscoOrganizations() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - setDisconnectedErr := state.SetDisconnected() - return getError(setDisconnectedErr) -} - -//export SetDisconnecting -func SetDisconnecting(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + orgs, err := state.DiscoOrganizations() + if orgs == nil && err != nil { + return nil, getCError(err) } - setDisconnectingErr := state.SetDisconnecting() - return getError(setDisconnectingErr) -} - -//export SetConnecting -func SetConnecting(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + s, reterr := getReturnData(orgs) + if reterr != nil { + return nil, getCError(reterr) } - setConnectingErr := state.SetConnecting() - return getError(setConnectingErr) + return C.CString(s), getCError(err) } -//export SetConnected -func SetConnected(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export Cleanup +func Cleanup(prevTokens *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } - setConnectedErr := state.SetConnected() - return getError(setConnectedErr) + t, err := getTokens(prevTokens) + if err != nil { + return getCError(err) + } + err = state.Cleanup(t) + return getCError(err) } //export RenewSession -func RenewSession(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func RenewSession() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } renewSessionErr := state.RenewSession() - return getError(renewSessionErr) + return getCError(renewSessionErr) } //export ShouldRenewButton -func ShouldRenewButton(name *C.char) C.int { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func ShouldRenewButton() C.int { + state, stateErr := getVPNState() if stateErr != nil { return C.int(0) } @@ -545,47 +377,45 @@ func ShouldRenewButton(name *C.char) C.int { return C.int(0) } -//export InFSMState -func InFSMState(name *C.char, checkState C.int) C.int { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export SetSupportWireguard +func SetSupportWireguard(support C.int) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return C.int(0) + return getCError(stateErr) } - inStateBool := state.InFSMState(client.FSMStateID(checkState)) - if inStateBool { - return C.int(1) - } - return C.int(0) + state.SupportsWireguard = support == 1 + return nil } -//export SetSupportWireguard -func SetSupportWireguard(name *C.char, support C.int) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export SecureLocationList +func SecureLocationList() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - state.SupportsWireguard = support == 1 - return nil + locs := state.Discovery.SecureLocationList() + l, err := getReturnData(locs) + if err != nil { + return nil, getCError(err) + } + return C.CString(l), nil } //export StartFailover -func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return C.int(0), getError(stateErr) + return C.int(0), getCError(stateErr) } dropped, droppedErr := state.StartFailover(C.GoString(gateway), int(mtu), func() (int64, error) { rxBytes := int64(C.get_read_rx_bytes(readRxBytes)) - if rxBytes == -1 { + if rxBytes < 0 { return 0, errors.New("client gave an invalid rx bytes value") } return rxBytes, nil }) if droppedErr != nil { - return C.int(0), getError(droppedErr) + return C.int(0), getCError(droppedErr) } droppedC := C.int(0) if dropped { @@ -595,15 +425,14 @@ func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadR } //export CancelFailover -func CancelFailover(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func CancelFailover() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } cancelErr := state.CancelFailover() if cancelErr != nil { - return getError(cancelErr) + return getCError(cancelErr) } return nil } diff --git a/exports/servers.go b/exports/servers.go deleted file mode 100644 index 73b8b6c..0000000 --- a/exports/servers.go +++ /dev/null @@ -1,342 +0,0 @@ -package main - -/* -// for free and size_t -#include -#include "error.h" -#include "server.h" -*/ -import "C" - -import ( - "unsafe" - - "github.com/eduvpn/eduvpn-common/client" - "github.com/eduvpn/eduvpn-common/internal/server" -) - -// Get the pointer to the C struct for the profile -// We allocate the struct, the profile ID and the display name -func getCPtrProfile(profile *server.Profile) *C.serverProfile { - // Allocate the struct using malloc and the size of the struct - cProfile := (*C.serverProfile)(C.malloc(C.size_t(unsafe.Sizeof(C.serverProfile{})))) - cProfile.id = C.CString(profile.ID) - cProfile.display_name = C.CString(profile.DisplayName) - if profile.DefaultGateway { - cProfile.default_gateway = C.int(1) - } else { - cProfile.default_gateway = C.int(0) - } - - return cProfile -} - -// Get the pointer to the C struct for the profiles -// We allocate the struct and the struct inside it for the profiles -func getCPtrProfiles(serverProfiles *server.ProfileInfo) *C.serverProfiles { - goProfiles := serverProfiles.Info.ProfileList - // Allocate the profles struct using malloc and the size of a pointer - cProfiles := (*C.serverProfiles)(C.malloc(C.size_t(uintptr(0)))) - totalProfiles := C.size_t(len(goProfiles)) - // Defaults if we have no profiles - cProfiles.current = C.int(0) - cProfiles.profiles = nil - cProfiles.total_profiles = totalProfiles - // If we have profiles (which we should), we allocate the struct with malloc and the size of a pointer - // We then fill the struct by converting it to a go slice and get a C pointer for each profile - if totalProfiles > 0 { - profilesPtr := (**C.serverProfile)(C.malloc(totalProfiles * C.size_t(unsafe.Sizeof(uintptr(0))))) - profiles := unsafe.Slice(profilesPtr, totalProfiles) - index := 0 - for _, profile := range goProfiles { - profiles[index] = getCPtrProfile(&profile) - index++ - } - cProfiles.current = C.int(serverProfiles.CurrentProfileIndex()) - cProfiles.profiles = (**C.serverProfile)(profilesPtr) - } - return cProfiles -} - -// Free the profiles by looping through them if there are any -// Also free the pointer itself -// -//export FreeProfiles -func FreeProfiles(profiles *C.serverProfiles) { - // We should only free the profiles if we have them (which we should) - if profiles.total_profiles > 0 { - // Convert it to a go slice - profilesSlice := unsafe.Slice(profiles.profiles, profiles.total_profiles) - // Loop through the pointers and free th allocated strings and the struct itself - for i := C.size_t(0); i < profiles.total_profiles; i++ { - C.free(unsafe.Pointer(profilesSlice[i].id)) - C.free(unsafe.Pointer(profilesSlice[i].display_name)) - C.free(unsafe.Pointer(profilesSlice[i])) - } - // Free the inner profiles struct - C.free(unsafe.Pointer(profiles.profiles)) - } - // Free the profiles struct itself - C.free(unsafe.Pointer(profiles)) -} - -// Get a list of strings with a size as a c structure -// Returns the size in size_t and the list of strings as a double pointer char -func getCPtrListStrings(allStrings []string) (C.size_t, **C.char) { - // Get the total strings in size_t - totalStrings := C.size_t(len(allStrings)) - - // If we have strings - // Allocate memory for the strings array - if totalStrings > 0 { - stringsPtr := (**C.char)(C.malloc(totalStrings * C.size_t(unsafe.Sizeof(uintptr(0))))) - // Go slice conversion - cStrings := unsafe.Slice(stringsPtr, totalStrings) - - // Loop through and allocate the string for each contact - for index, string := range allStrings { - cStrings[index] = C.CString(string) - } - return totalStrings, (**C.char)(stringsPtr) - } - - // No strings then the length is zero and the char array is nil - return C.size_t(0), nil -} - -// Function for freeing an array/list of strings -// It takes the strings as a pointer to a string and the total strings in size_t -func freeCListStrings(allStrings **C.char, totalStrings C.size_t) { - // If we have strings we should free them - // By converting to a Go slice, and freeing them ony by one - // At last free the pointer itself - if totalStrings > 0 { - stringsSlice := unsafe.Slice(allStrings, totalStrings) - for i := C.size_t(0); i < totalStrings; i++ { - C.free(unsafe.Pointer(stringsSlice[i])) - } - C.free(unsafe.Pointer(allStrings)) - } -} - -// Function for getting the server, -// It gets the main state as a pointer as we need to convert some string maps to localized strings -// It gets the base information for a server as well -func getCPtrServer(state *client.Client, base *client.ServerBase) *C.server { - // Allocation using malloc and the size of the struct - cServer := (*C.server)(C.malloc(C.size_t(unsafe.Sizeof(C.server{})))) - // String allocation and translate the display name - identifier := base.URL - countryCode := "" - // A secure internet server has multiple locations - locations := []string{} - if base.Type == "secure_internet" { - identifier = state.Servers.SecureInternetHomeServer.HomeOrganizationID - countryCode = state.Servers.SecureInternetHomeServer.CurrentLocation - locations = state.Discovery.SecureLocationList() - } - - cServer.identifier = C.CString(identifier) - cServer.display_name = C.CString(state.GetTranslated(base.DisplayName)) - cServer.country_code = C.CString(countryCode) - cServer.server_type = C.CString(base.Type) - // Call the helper to get the list of support contacts - cServer.total_support_contact, cServer.support_contact = getCPtrListStrings( - base.SupportContact, - ) - locationsStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{})))) - locationsStruct.total_locations, locationsStruct.locations = getCPtrListStrings(locations) - cServer.locations = locationsStruct - - profiles := base.ValidProfiles(state.SupportsWireguard) - cServer.profiles = getCPtrProfiles(&profiles) - // No endtime is given if we get servers when it has been partially initialised - if base.EndTime.IsZero() { - cServer.expire_time = C.ulonglong(0) - } else { - // The expire time should be stored as an unsigned long long in unix time - cServer.expire_time = C.ulonglong(base.EndTime.Unix()) - } - return cServer -} - -// Function for freeing a single server -// Gets the pointer to C struct -// -//export FreeServer -func FreeServer(info *C.server) { - // Free strings - C.free(unsafe.Pointer(info.identifier)) - C.free(unsafe.Pointer(info.display_name)) - C.free(unsafe.Pointer(info.country_code)) - C.free(unsafe.Pointer(info.server_type)) - - // Free arrays - freeCListStrings(info.support_contact, info.total_support_contact) - FreeSecureLocations(info.locations) - FreeProfiles(info.profiles) - - // Free the struct itself - C.free(unsafe.Pointer(info)) -} - -// Get the C ptr to the servers, returns the length in size_t and the double pointer to the struct -func getCPtrServers( - state *client.Client, - serverMap map[string]*server.InstituteAccessServer, -) (C.size_t, **C.server) { - totalServers := C.size_t(len(serverMap)) - // If we have servers, which is not always the case - if totalServers > 0 { - serversPtr := (**C.server)(C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0))))) - servers := unsafe.Slice(serversPtr, totalServers) - index := 0 - for _, currentServer := range serverMap { - cServer := getCPtrServer(state, ¤tServer.Basic) - servers[index] = cServer - index++ - } - return totalServers, serversPtr - } - return C.size_t(0), nil -} - -// This function takes the servers as a C struct pointer as input -// It frees all allocated memory for the server -// -//export FreeServers -func FreeServers(cServers *C.servers) { - // Free the custom servers if there are any - if cServers.total_custom > 0 { - customServers := unsafe.Slice(cServers.custom_servers, cServers.total_custom) - for i := C.size_t(0); i < cServers.total_custom; i++ { - FreeServer(customServers[i]) - } - C.free(unsafe.Pointer(cServers.custom_servers)) - } - // Free the institute access servers if there are any - if cServers.total_institute > 0 { - instituteServers := unsafe.Slice(cServers.institute_servers, cServers.total_institute) - - for i := C.size_t(0); i < cServers.total_institute; i++ { - FreeServer(instituteServers[i]) - } - C.free(unsafe.Pointer(cServers.institute_servers)) - } - // Free the secure internet server if there is one - if cServers.secure_internet_server != nil { - FreeServer(cServers.secure_internet_server) - } - // Free the structure itself - C.free(unsafe.Pointer(cServers)) -} - -// Return the servers as a C struct pointer -// It takes the state as a pointer as we need to translate some strings -// It also takes the servers as a pointer that belongs to the main state or gathered from the callback -func getSavedServersWithOptions(state *client.Client, servers *server.Servers) *C.servers { - // Allocate the struct that we will return - // With the size of the c struct - returnedStruct := (*C.servers)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{})))) - - // Get the different categories of servers - totalCustom, customPtr := getCPtrServers(state, servers.CustomServers.Map) - totalInstitute, institutePtr := getCPtrServers(state, servers.InstituteServers.Map) - var secureServerPtr *C.server - secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.Base() - if secureInternetBaseErr == nil && secureInternetBase != nil { - // FIXME: log error? - secureServerPtr = getCPtrServer(state, secureInternetBase) - // Give a new identifier - C.free(unsafe.Pointer(secureServerPtr.identifier)) - secureServerPtr.identifier = C.CString(servers.SecureInternetHomeServer.HomeOrganizationID) - secureServerPtr.country_code = C.CString(servers.SecureInternetHomeServer.CurrentLocation) - } - - // Fill the struct and return - returnedStruct.custom_servers = customPtr - returnedStruct.total_custom = totalCustom - returnedStruct.institute_servers = institutePtr - returnedStruct.total_institute = totalInstitute - returnedStruct.secure_internet_server = secureServerPtr - return returnedStruct -} - -// This function takes the name as input which is the name of the client -// It gets the state by name and then returns the saved servers as a c struct belonging to it -// -//export GetSavedServers -func GetSavedServers(name *C.char) (*C.servers, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, getError(stateErr) - } - servers := getSavedServersWithOptions(state, &state.Servers) - return servers, nil -} - -// This function takes the name as input which is the name of the client -// It gets the state by name and then returns the current server as a c struct belonging to it -// -//export GetCurrentServer -func GetCurrentServer(name *C.char) (*C.server, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, getError(stateErr) - } - server, serverErr := state.Servers.GetCurrentServer() - if serverErr != nil { - return nil, getError(serverErr) - } - base, baseErr := server.Base() - if baseErr != nil { - return nil, getError(baseErr) - } - cServer := getCPtrServer(state, base) - return cServer, nil -} - -// This function takes the state as input which is the main state -// It also takes the data as an interface and if it has the servers type gets the data as a c struct otherwise nil -func getTransitionDataServers(state *client.Client, data interface{}) *C.servers { - if converted, ok := data.(server.Servers); ok { - return getSavedServersWithOptions(state, &converted) - } - return nil -} - -//export FreeSecureLocations -func FreeSecureLocations(locations *C.serverLocations) { - freeCListStrings(locations.locations, locations.total_locations) - C.free(unsafe.Pointer(locations)) -} - -func getTransitionSecureLocations(data interface{}) *C.serverLocations { - if locations, ok := data.([]string); ok { - returnedStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{})))) - returnedStruct.total_locations, returnedStruct.locations = getCPtrListStrings(locations) - return returnedStruct - } - return nil -} - -func getTransitionProfiles(data interface{}) *C.serverProfiles { - if profiles, ok := data.(*server.ProfileInfo); ok { - return getCPtrProfiles(profiles) - } - return nil -} - -func getTransitionServer(state *client.Client, data interface{}) *C.server { - if server, ok := data.(server.Server); ok { - base, baseErr := server.Base() - if baseErr != nil { - // TODO: LOG - return nil - } - return getCPtrServer(state, base) - } - return nil -} -- cgit v1.2.3