summaryrefslogtreecommitdiff
path: root/exports
diff options
context:
space:
mode:
Diffstat (limited to 'exports')
-rw-r--r--exports/disco.go217
-rw-r--r--exports/exports.go643
-rw-r--r--exports/servers.go342
3 files changed, 236 insertions, 966 deletions
diff --git a/exports/disco.go b/exports/disco.go
deleted file mode 100644
index 08c921f..0000000
--- a/exports/disco.go
+++ /dev/null
@@ -1,217 +0,0 @@
-package main
-
-/*
-// for free and size_t
-#include <stdlib.h>
-#include "error.h"
-
-typedef struct discoveryServer {
- const char* authentication_url_template;
- const char* base_url;
- const char* country_code;
- const char* display_name;
- const char* keyword_list;
- const char** public_key_list;
- size_t total_public_keys;
- const char* server_type;
- const char** support_contact;
- size_t total_support_contact;
-} discoveryServer;
-
-typedef struct discoveryServers {
- unsigned long long int version;
- discoveryServer** servers;
- size_t total_servers;
-} discoveryServers;
-
-typedef struct discoveryOrganization {
- const char* display_name;
- const char* org_id;
- const char* secure_internet_home;
- const char* keyword_list;
-} discoveryOrganization;
-
-typedef struct discoveryOrganizations {
- unsigned long long int version;
- discoveryOrganization** organizations;
- size_t total_organizations;
-} discoveryOrganizations;
-*/
-import "C"
-
-import (
- "unsafe"
-
- "github.com/eduvpn/eduvpn-common/client"
- "github.com/eduvpn/eduvpn-common/types"
-)
-
-func getCPtrDiscoOrganization(
- state *client.Client,
- organization *types.DiscoveryOrganization,
-) *C.discoveryOrganization {
- returnedStruct := (*C.discoveryOrganization)(
- C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganization{}))),
- )
- returnedStruct.display_name = C.CString(state.GetTranslated(organization.DisplayName))
- returnedStruct.org_id = C.CString(organization.OrgID)
- returnedStruct.secure_internet_home = C.CString(organization.SecureInternetHome)
- returnedStruct.keyword_list = C.CString(state.GetTranslated(organization.KeywordList))
- return returnedStruct
-}
-
-func getCPtrDiscoOrganizations(
- state *client.Client,
- organizations *types.DiscoveryOrganizations,
-) (C.size_t, **C.discoveryOrganization) {
- totalOrganizations := C.size_t(len(organizations.List))
- if totalOrganizations > 0 {
- organizationsPtr := (**C.discoveryOrganization)(
- C.malloc(totalOrganizations * C.size_t(unsafe.Sizeof(uintptr(0)))),
- )
- cOrganizations := unsafe.Slice(organizationsPtr, totalOrganizations)
- index := 0
- for _, organization := range organizations.List {
- cOrganization := getCPtrDiscoOrganization(state, &organization)
- cOrganizations[index] = cOrganization
- index++
- }
- return totalOrganizations, organizationsPtr
- }
- return 0, nil
-}
-
-func getCPtrDiscoServer(
- state *client.Client,
- server *types.DiscoveryServer,
-) *C.discoveryServer {
- returnedStruct := (*C.discoveryServer)(
- C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServer{}))),
- )
- returnedStruct.authentication_url_template = C.CString(server.AuthenticationURLTemplate)
- returnedStruct.base_url = C.CString(server.BaseURL)
- returnedStruct.country_code = C.CString(server.CountryCode)
- returnedStruct.display_name = C.CString(state.GetTranslated(server.DisplayName))
- returnedStruct.keyword_list = C.CString(state.GetTranslated(server.KeywordList))
- returnedStruct.total_public_keys, returnedStruct.public_key_list = getCPtrListStrings(
- server.PublicKeyList,
- )
- returnedStruct.server_type = C.CString(server.Type)
- returnedStruct.total_support_contact, returnedStruct.support_contact = getCPtrListStrings(
- server.SupportContact,
- )
- return returnedStruct
-}
-
-func getCPtrDiscoServers(
- state *client.Client,
- servers *types.DiscoveryServers,
-) (C.size_t, **C.discoveryServer) {
- totalServers := C.size_t(len(servers.List))
- if totalServers > 0 {
- serversPtr := (**C.discoveryServer)(
- C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0)))),
- )
- cServers := unsafe.Slice(serversPtr, totalServers)
- index := 0
- for _, server := range servers.List {
- cServer := getCPtrDiscoServer(state, &server)
- cServers[index] = cServer
- index++
- }
- return totalServers, serversPtr
- }
- return 0, nil
-}
-
-func freeDiscoOrganization(cOrganization *C.discoveryOrganization) {
- C.free(unsafe.Pointer(cOrganization.display_name))
- C.free(unsafe.Pointer(cOrganization.org_id))
- C.free(unsafe.Pointer(cOrganization.secure_internet_home))
- C.free(unsafe.Pointer(cOrganization.keyword_list))
- C.free(unsafe.Pointer(cOrganization))
-}
-
-func freeDiscoServer(cServer *C.discoveryServer) {
- C.free(unsafe.Pointer(cServer.authentication_url_template))
- C.free(unsafe.Pointer(cServer.base_url))
- C.free(unsafe.Pointer(cServer.country_code))
- C.free(unsafe.Pointer(cServer.display_name))
- C.free(unsafe.Pointer(cServer.keyword_list))
- freeCListStrings(cServer.public_key_list, cServer.total_public_keys)
- C.free(unsafe.Pointer(cServer.server_type))
- freeCListStrings(cServer.support_contact, cServer.total_support_contact)
- C.free(unsafe.Pointer(cServer))
-}
-
-//export FreeDiscoServers
-func FreeDiscoServers(cServers *C.discoveryServers) {
- if cServers.total_servers > 0 {
- servers := unsafe.Slice(cServers.servers, cServers.total_servers)
- for i := C.size_t(0); i < cServers.total_servers; i++ {
- freeDiscoServer(servers[i])
- }
- C.free(unsafe.Pointer(cServers.servers))
- }
- C.free(unsafe.Pointer(cServers))
-}
-
-//export FreeDiscoOrganizations
-func FreeDiscoOrganizations(cOrganizations *C.discoveryOrganizations) {
- if cOrganizations.total_organizations > 0 {
- organizations := unsafe.Slice(cOrganizations.organizations, cOrganizations.total_organizations)
- for i := C.size_t(0); i < cOrganizations.total_organizations; i++ {
- freeDiscoOrganization(organizations[i])
- }
- C.free(unsafe.Pointer(cOrganizations.organizations))
- }
- C.free(unsafe.Pointer(cOrganizations))
-}
-
-//export GetDiscoServers
-func GetDiscoServers(name *C.char) (*C.discoveryServers, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, getError(stateErr)
- }
- servers, serversErr := state.DiscoServers()
- // if we get no servers then we return immediately
- if servers == nil {
- return nil, getError(serversErr)
- }
- returnedStruct := (*C.discoveryServers)(
- C.malloc(C.size_t(unsafe.Sizeof(C.discoveryServers{}))),
- )
- returnedStruct.version = C.ulonglong(servers.Version)
- returnedStruct.total_servers, returnedStruct.servers = getCPtrDiscoServers(
- state,
- servers,
- )
- return returnedStruct, getError(serversErr)
-}
-
-//export GetDiscoOrganizations
-func GetDiscoOrganizations(name *C.char) (*C.discoveryOrganizations, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, getError(stateErr)
- }
- organizations, organizationsErr := state.DiscoOrganizations()
- // if we get no organizations then we return immediately
- if organizations == nil {
- return nil, getError(organizationsErr)
- }
- returnedStruct := (*C.discoveryOrganizations)(
- C.malloc(C.size_t(unsafe.Sizeof(C.discoveryOrganizations{}))),
- )
-
- returnedStruct.version = C.ulonglong(organizations.Version)
- returnedStruct.total_organizations, returnedStruct.organizations = getCPtrDiscoOrganizations(
- state,
- organizations,
- )
-
- return returnedStruct, getError(organizationsErr)
-}
diff --git a/exports/exports.go b/exports/exports.go
index c09e980..d8fd6ea 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -6,114 +6,87 @@ package main
#include "server.h"
typedef long long int (*ReadRxBytes)();
-typedef struct token {
- const char* access;
- const char* refresh;
- unsigned long long int expired;
-} token;
-typedef void (*UpdateToken)(const char* name, server* srv, token* tok);
-
-typedef struct configData {
- const char* config;
- const char* config_type;
- token* tokens;
-} configData;
-
-typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data);
+typedef int (*StateCB)(int oldstate, int newstate, void* data);
static long long int get_read_rx_bytes(ReadRxBytes read)
{
return read();
}
-
-static void update_token(UpdateToken func, const char* name, server* srv, token* tok)
-{
- func(name, srv, tok);
-}
-
-static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data)
+static int call_callback(StateCB callback, int oldstate, int newstate, void* data)
{
- return callback(name, oldstate, newstate, data);
+ return callback(oldstate, newstate, data);
}
*/
import "C"
import (
- "time"
+ "encoding/json"
"unsafe"
- "github.com/eduvpn/eduvpn-common/internal/log"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/internal/server"
"github.com/go-errors/errors"
"github.com/eduvpn/eduvpn-common/client"
+ "github.com/eduvpn/eduvpn-common/types"
)
-var PStateCallbacks map[string]C.PythonCB
+var (
+ PStateCallback C.StateCB
+ VPNState *client.Client
+)
-var VPNStates map[string]*client.Client
+func getTokens(tokens *C.char) (t types.Tokens, err error) {
+ err = json.Unmarshal([]byte(C.GoString(tokens)), &t)
+ return t, err
+}
-func GetStateData(
- state *client.Client,
- stateID client.FSMStateID,
- data interface{},
-) unsafe.Pointer {
- switch stateID {
- case client.StateNoServer:
- return (unsafe.Pointer)(getTransitionDataServers(state, data))
- case client.StateOAuthStarted:
- if converted, ok := data.(string); ok {
- return (unsafe.Pointer)(C.CString(converted))
- }
- case client.StateAskLocation:
- return (unsafe.Pointer)(getTransitionSecureLocations(data))
- case client.StateAskProfile:
- return (unsafe.Pointer)(getTransitionProfiles(data))
- case client.StateDisconnected:
- return (unsafe.Pointer)(getTransitionServer(state, data))
- case client.StateDisconnecting:
- return (unsafe.Pointer)(getTransitionServer(state, data))
- case client.StateConnecting:
- return (unsafe.Pointer)(getTransitionServer(state, data))
- case client.StateConnected:
- return (unsafe.Pointer)(getTransitionServer(state, data))
- default:
- return nil
+func getCError(err error) *C.char {
+ if err == nil {
+ return C.CString("")
}
- return nil
+ return C.CString(err.Error())
+}
+
+func getReturnData(data interface{}) (string, error) {
+ // If it is already a string return directly
+ if x, ok := data.(string); ok {
+ return x, nil
+ }
+
+ // Otherwise use JSON
+ b, err := json.Marshal(data)
+ if err != nil {
+ return "", err
+ }
+ return string(b), nil
}
func StateCallback(
state *client.Client,
- name string,
oldState client.FSMStateID,
newState client.FSMStateID,
data interface{},
) bool {
- PStateCallback, exists := PStateCallbacks[name]
- if !exists || PStateCallback == nil {
+ if PStateCallback == nil {
return false
}
- nameC := C.CString(name)
oldStateC := C.int(oldState)
newStateC := C.int(newState)
- dataC := GetStateData(state, newState, data)
- handled := C.call_callback(PStateCallback, nameC, oldStateC, newStateC, dataC)
- C.free(unsafe.Pointer(nameC))
- // data_c gets freed by the wrapper
+ d, err := getReturnData(data)
+ if err != nil {
+ return false
+ }
+ dataC := C.CString(d)
+ handled := C.call_callback(PStateCallback, oldStateC, newStateC, unsafe.Pointer(dataC))
+ FreeString(dataC)
return handled == C.int(1)
}
-func GetVPNState(name string) (*client.Client, error) {
- state, exists := VPNStates[name]
-
- if !exists || state == nil {
- return nil, errors.Errorf("state with name %s not found", name)
+func getVPNState() (*client.Client, error) {
+ if VPNState == nil {
+ return nil, errors.New("No state available, did you register the client?")
}
-
- return state, nil
+ return VPNState, nil
}
//export Register
@@ -121,39 +94,44 @@ func Register(
name *C.char,
version *C.char,
configDirectory *C.char,
- language *C.char,
- stateCallback C.PythonCB,
+ stateCallback C.StateCB,
debug C.int,
-) *C.error {
- nameStr := C.GoString(name)
- versionStr := C.GoString(version)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- state = &client.Client{}
- }
- if VPNStates == nil {
- VPNStates = make(map[string]*client.Client)
- }
- if PStateCallbacks == nil {
- PStateCallbacks = make(map[string]C.PythonCB)
- }
- VPNStates[nameStr] = state
- PStateCallbacks[nameStr] = stateCallback
+) *C.char {
+ state, stateErr := getVPNState()
+ if stateErr == nil {
+ return getCError(errors.New("failed to register, a VPN state is already present"))
+ }
+ state = &client.Client{}
+ VPNState = state
+ PStateCallback = stateCallback
registerErr := state.Register(
- nameStr,
- versionStr,
+ C.GoString(name),
+ C.GoString(version),
C.GoString(configDirectory),
- C.GoString(language),
func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool {
- return StateCallback(state, nameStr, old, new, data)
+ return StateCallback(state, old, new, data)
},
- debug != 0,
+ debug == 1,
)
- if registerErr != nil {
- delete(VPNStates, nameStr)
+ return getCError(registerErr)
+}
+
+//export ExpiryTimes
+func ExpiryTimes() (*C.char, *C.char) {
+ state, stateErr := getVPNState()
+ if stateErr != nil {
+ return nil, getCError(stateErr)
+ }
+ exp, err := state.ExpiryTimes()
+ if err != nil {
+ return nil, getCError(err)
+ }
+ ret, err := getReturnData(exp)
+ if err != nil {
+ return nil, getCError(err)
}
- return getError(registerErr)
+ return C.CString(ret), nil
}
//export SetTokenUpdater
@@ -179,362 +157,216 @@ func SetTokenUpdater(name *C.char, updater C.UpdateToken) *C.error {
}
//export Deregister
-func Deregister(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func Deregister() *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
state.Deregister()
+ VPNState = nil
return nil
}
-func getError(err error) *C.error {
- if err == nil {
- return nil
- }
- errorStruct := (*C.error)(
- C.malloc(C.size_t(unsafe.Sizeof(C.error{}))),
- )
- if err1, ok := err.(*errors.Error); ok {
- if err1 == nil {
- errorStruct.traceback = C.CString("N/A")
- errorStruct.cause = C.CString("unknown error")
- return errorStruct
- }
- errorStruct.traceback = C.CString(err1.ErrorStack())
- if err1.Err == nil {
- errorStruct.cause = C.CString(err1.Error())
- } else {
- errorStruct.cause = C.CString(err1.Err.Error())
- }
- } else {
- errorStruct.traceback = C.CString("N/A")
- errorStruct.cause = C.CString(err.Error())
- }
- return errorStruct
-}
-
-//export FreeError
-func FreeError(err *C.error) {
- C.free(unsafe.Pointer(err.traceback))
- C.free(unsafe.Pointer(err.cause))
- C.free(unsafe.Pointer(err))
-}
-
//export CancelOAuth
-func CancelOAuth(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func CancelOAuth() *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
cancelErr := state.CancelOAuth()
- return getError(cancelErr)
+ return getCError(cancelErr)
}
-//export RemoveSecureInternet
-func RemoveSecureInternet(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export AddServer
+func AddServer(_type *C.char, id *C.char) *C.char {
+ // TODO: type
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
- }
- removeErr := state.RemoveSecureInternet()
- return getError(removeErr)
-}
-
-//export AddInstituteAccess
-func AddInstituteAccess(name *C.char, url *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
+ }
+ t := C.GoString(_type)
+ var err error
+ switch t {
+ case "institute_access":
+ err = state.AddInstituteServer(C.GoString(id))
+ case "secure_internet":
+ err = state.AddSecureInternetHomeServer(C.GoString(id))
+ case "custom_server":
+ err = state.AddCustomServer(C.GoString(id))
+ default:
+ err = errors.Errorf("invalid type: %v", t)
}
- // FIXME: Return server result
- _, addErr := state.AddInstituteServer(C.GoString(url))
- return getError(addErr)
+ return getCError(err)
}
-//export AddSecureInternetHomeServer
-func AddSecureInternetHomeServer(name *C.char, orgID *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export RemoveServer
+func RemoveServer(_type *C.char, id *C.char) *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
- }
- // FIXME: Return server result
- _, addErr := state.AddSecureInternetHomeServer(C.GoString(orgID))
- return getError(addErr)
-}
-
-//export AddCustomServer
-func AddCustomServer(name *C.char, url *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
+ }
+ t := C.GoString(_type)
+ var err error
+ switch t {
+ case "institute_access":
+ err = state.RemoveInstituteAccess(C.GoString(id))
+ case "secure_internet":
+ err = state.RemoveSecureInternet()
+ case "custom_server":
+ err = state.RemoveCustomServer(C.GoString(id))
+ default:
+ err = errors.Errorf("invalid type: %v", t)
}
- // FIXME: Return server result
- _, addErr := state.AddCustomServer(C.GoString(url))
- return getError(addErr)
+ return getCError(err)
}
-//export RemoveInstituteAccess
-func RemoveInstituteAccess(name *C.char, url *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export CurrentServer
+func CurrentServer() (*C.char, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return nil, getCError(stateErr)
}
- removeErr := state.RemoveInstituteAccess(C.GoString(url))
- return getError(removeErr)
-}
-
-//export RemoveCustomServer
-func RemoveCustomServer(name *C.char, url *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
+ srv, err := state.CurrentServer()
+ if err != nil {
+ return nil, getCError(err)
}
- removeErr := state.RemoveCustomServer(C.GoString(url))
- return getError(removeErr)
-}
-
-func cToken(t oauth.Token) *C.token {
- cTok := (*C.token)(C.malloc(C.size_t(unsafe.Sizeof(C.token{}))))
- cTok.access = C.CString(t.Access)
- cTok.refresh = C.CString(t.Refresh)
- cTok.expired = C.ulonglong(t.ExpiredTimestamp.Unix())
- return cTok
-}
-
-func cConfig(config *client.ConfigData) *C.configData {
- // No config so return nil pointer
- if config == nil {
- return nil
+ ret, err := getReturnData(srv)
+ if err != nil {
+ return nil, getCError(err)
}
- cConf := (*C.configData)(C.malloc(C.size_t(unsafe.Sizeof(C.configData{}))))
- cConf.config = C.CString(config.Config)
- cConf.config_type = C.CString(config.Type)
- cConf.tokens = cToken(config.Tokens)
- return cConf
+ return C.CString(ret), nil
}
-//export FreeTokens
-func FreeTokens(tokens *C.token) {
- C.free(unsafe.Pointer(tokens.access))
- C.free(unsafe.Pointer(tokens.refresh))
- C.free(unsafe.Pointer(tokens))
-}
-
-//export FreeConfig
-func FreeConfig(config *C.configData) {
- C.free(unsafe.Pointer(config.config))
- C.free(unsafe.Pointer(config.config_type))
- C.free(unsafe.Pointer(config.tokens.access))
- C.free(unsafe.Pointer(config.tokens.refresh))
- C.free(unsafe.Pointer(config.tokens))
- C.free(unsafe.Pointer(config))
-}
-
-//export GetConfigSecureInternet
-func GetConfigSecureInternet(
- name *C.char,
- orgID *C.char,
- preferTCP C.int,
- prevTokens C.token,
-) (*C.configData, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export ServerList
+func ServerList() (*C.char, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return nil, getError(stateErr)
- }
- preferTCPBool := preferTCP == 1
- t := oauth.Token{
- Access: C.GoString(prevTokens.access),
- Refresh: C.GoString(prevTokens.refresh),
- ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ return nil, getCError(stateErr)
}
- cfg, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool, t)
- return cConfig(cfg), getError(configErr)
-}
-
-//export GetConfigInstituteAccess
-func GetConfigInstituteAccess(
- name *C.char,
- url *C.char,
- preferTCP C.int,
- prevTokens C.token,
-) (*C.configData, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, getError(stateErr)
+ list, err := state.ServerList()
+ if err != nil {
+ return nil, getCError(err)
}
- preferTCPBool := preferTCP == 1
- t := oauth.Token{
- Access: C.GoString(prevTokens.access),
- Refresh: C.GoString(prevTokens.refresh),
- ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ ret, err := getReturnData(list)
+ if err != nil {
+ return nil, getCError(err)
}
- cfg, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool, t)
- return cConfig(cfg), getError(configErr)
+ return C.CString(ret), nil
}
-//export GetConfigCustomServer
-func GetConfigCustomServer(
- name *C.char,
- url *C.char,
- preferTCP C.int,
- prevTokens C.token,
-) (*C.configData, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export GetConfig
+func GetConfig(_type *C.char, id *C.char, pTCP C.int, tokens *C.char) (*C.char, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return nil, getError(stateErr)
+ return nil, getCError(stateErr)
+ }
+ preferTCPBool := pTCP == 1
+ tok, err := getTokens(tokens)
+ if err != nil {
+ return nil, getCError(err)
+ }
+ t := C.GoString(_type)
+ var cfg *types.Configuration
+ switch t {
+ case "institute_access":
+ cfg, err = state.GetConfigInstituteAccess(C.GoString(id), preferTCPBool, tok)
+ case "secure_internet":
+ cfg, err = state.GetConfigSecureInternet(C.GoString(id), preferTCPBool, tok)
+ case "custom_server":
+ cfg, err = state.GetConfigCustomServer(C.GoString(id), preferTCPBool, tok)
+ default:
+ err = errors.Errorf("invalid type: %v", t)
}
- preferTCPBool := preferTCP == 1
- t := oauth.Token{
- Access: C.GoString(prevTokens.access),
- Refresh: C.GoString(prevTokens.refresh),
- ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ if cfg != nil && err == nil {
+ d, err := getReturnData(cfg)
+ if err == nil {
+ return C.CString(d), nil
+ }
}
- cfg, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool, t)
- return cConfig(cfg), getError(configErr)
+ return nil, getCError(err)
}
//export SetProfileID
-func SetProfileID(name *C.char, data *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func SetProfileID(data *C.char) *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
profileErr := state.SetProfileID(C.GoString(data))
- return getError(profileErr)
-}
-
-//export ChangeSecureLocation
-func ChangeSecureLocation(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
- }
- locationErr := state.ChangeSecureLocation()
- return getError(locationErr)
+ return getCError(profileErr)
}
//export SetSecureLocation
-func SetSecureLocation(name *C.char, data *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func SetSecureLocation(data *C.char) *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
locationErr := state.SetSecureLocation(C.GoString(data))
- return getError(locationErr)
-}
-
-//export GoBack
-func GoBack(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
- }
- goBackErr := state.GoBack()
- return getError(goBackErr)
+ return getCError(locationErr)
}
-//export SetSearchServer
-func SetSearchServer(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export DiscoServers
+func DiscoServers() (*C.char, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return nil, getCError(stateErr)
}
- setSearchErr := state.SetSearchServer()
- return getError(setSearchErr)
-}
-
-//export Cleanup
-func Cleanup(name *C.char, prevTokens C.token) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
+ servers, err := state.DiscoServers()
+ if servers == nil && err != nil {
+ return nil, getCError(err)
}
- t := oauth.Token{
- Access: C.GoString(prevTokens.access),
- Refresh: C.GoString(prevTokens.refresh),
- ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0),
+ s, reterr := getReturnData(servers)
+ if reterr != nil {
+ return nil, getCError(reterr)
}
- err := state.Cleanup(t)
- return getError(err)
+ return C.CString(s), getCError(err)
}
-//export SetDisconnected
-func SetDisconnected(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export DiscoOrganizations
+func DiscoOrganizations() (*C.char, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return nil, getCError(stateErr)
}
- setDisconnectedErr := state.SetDisconnected()
- return getError(setDisconnectedErr)
-}
-
-//export SetDisconnecting
-func SetDisconnecting(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
+ orgs, err := state.DiscoOrganizations()
+ if orgs == nil && err != nil {
+ return nil, getCError(err)
}
- setDisconnectingErr := state.SetDisconnecting()
- return getError(setDisconnectingErr)
-}
-
-//export SetConnecting
-func SetConnecting(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return getError(stateErr)
+ s, reterr := getReturnData(orgs)
+ if reterr != nil {
+ return nil, getCError(reterr)
}
- setConnectingErr := state.SetConnecting()
- return getError(setConnectingErr)
+ return C.CString(s), getCError(err)
}
-//export SetConnected
-func SetConnected(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export Cleanup
+func Cleanup(prevTokens *C.char) *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
- setConnectedErr := state.SetConnected()
- return getError(setConnectedErr)
+ t, err := getTokens(prevTokens)
+ if err != nil {
+ return getCError(err)
+ }
+ err = state.Cleanup(t)
+ return getCError(err)
}
//export RenewSession
-func RenewSession(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func RenewSession() *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
renewSessionErr := state.RenewSession()
- return getError(renewSessionErr)
+ return getCError(renewSessionErr)
}
//export ShouldRenewButton
-func ShouldRenewButton(name *C.char) C.int {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func ShouldRenewButton() C.int {
+ state, stateErr := getVPNState()
if stateErr != nil {
return C.int(0)
}
@@ -545,47 +377,45 @@ func ShouldRenewButton(name *C.char) C.int {
return C.int(0)
}
-//export InFSMState
-func InFSMState(name *C.char, checkState C.int) C.int {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export SetSupportWireguard
+func SetSupportWireguard(support C.int) *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return C.int(0)
+ return getCError(stateErr)
}
- inStateBool := state.InFSMState(client.FSMStateID(checkState))
- if inStateBool {
- return C.int(1)
- }
- return C.int(0)
+ state.SupportsWireguard = support == 1
+ return nil
}
-//export SetSupportWireguard
-func SetSupportWireguard(name *C.char, support C.int) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+//export SecureLocationList
+func SecureLocationList() (*C.char, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return nil, getCError(stateErr)
}
- state.SupportsWireguard = support == 1
- return nil
+ locs := state.Discovery.SecureLocationList()
+ l, err := getReturnData(locs)
+ if err != nil {
+ return nil, getCError(err)
+ }
+ return C.CString(l), nil
}
//export StartFailover
-func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return C.int(0), getError(stateErr)
+ return C.int(0), getCError(stateErr)
}
dropped, droppedErr := state.StartFailover(C.GoString(gateway), int(mtu), func() (int64, error) {
rxBytes := int64(C.get_read_rx_bytes(readRxBytes))
- if rxBytes == -1 {
+ if rxBytes < 0 {
return 0, errors.New("client gave an invalid rx bytes value")
}
return rxBytes, nil
})
if droppedErr != nil {
- return C.int(0), getError(droppedErr)
+ return C.int(0), getCError(droppedErr)
}
droppedC := C.int(0)
if dropped {
@@ -595,15 +425,14 @@ func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadR
}
//export CancelFailover
-func CancelFailover(name *C.char) *C.error {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
+func CancelFailover() *C.char {
+ state, stateErr := getVPNState()
if stateErr != nil {
- return getError(stateErr)
+ return getCError(stateErr)
}
cancelErr := state.CancelFailover()
if cancelErr != nil {
- return getError(cancelErr)
+ return getCError(cancelErr)
}
return nil
}
diff --git a/exports/servers.go b/exports/servers.go
deleted file mode 100644
index 73b8b6c..0000000
--- a/exports/servers.go
+++ /dev/null
@@ -1,342 +0,0 @@
-package main
-
-/*
-// for free and size_t
-#include <stdlib.h>
-#include "error.h"
-#include "server.h"
-*/
-import "C"
-
-import (
- "unsafe"
-
- "github.com/eduvpn/eduvpn-common/client"
- "github.com/eduvpn/eduvpn-common/internal/server"
-)
-
-// Get the pointer to the C struct for the profile
-// We allocate the struct, the profile ID and the display name
-func getCPtrProfile(profile *server.Profile) *C.serverProfile {
- // Allocate the struct using malloc and the size of the struct
- cProfile := (*C.serverProfile)(C.malloc(C.size_t(unsafe.Sizeof(C.serverProfile{}))))
- cProfile.id = C.CString(profile.ID)
- cProfile.display_name = C.CString(profile.DisplayName)
- if profile.DefaultGateway {
- cProfile.default_gateway = C.int(1)
- } else {
- cProfile.default_gateway = C.int(0)
- }
-
- return cProfile
-}
-
-// Get the pointer to the C struct for the profiles
-// We allocate the struct and the struct inside it for the profiles
-func getCPtrProfiles(serverProfiles *server.ProfileInfo) *C.serverProfiles {
- goProfiles := serverProfiles.Info.ProfileList
- // Allocate the profles struct using malloc and the size of a pointer
- cProfiles := (*C.serverProfiles)(C.malloc(C.size_t(uintptr(0))))
- totalProfiles := C.size_t(len(goProfiles))
- // Defaults if we have no profiles
- cProfiles.current = C.int(0)
- cProfiles.profiles = nil
- cProfiles.total_profiles = totalProfiles
- // If we have profiles (which we should), we allocate the struct with malloc and the size of a pointer
- // We then fill the struct by converting it to a go slice and get a C pointer for each profile
- if totalProfiles > 0 {
- profilesPtr := (**C.serverProfile)(C.malloc(totalProfiles * C.size_t(unsafe.Sizeof(uintptr(0)))))
- profiles := unsafe.Slice(profilesPtr, totalProfiles)
- index := 0
- for _, profile := range goProfiles {
- profiles[index] = getCPtrProfile(&profile)
- index++
- }
- cProfiles.current = C.int(serverProfiles.CurrentProfileIndex())
- cProfiles.profiles = (**C.serverProfile)(profilesPtr)
- }
- return cProfiles
-}
-
-// Free the profiles by looping through them if there are any
-// Also free the pointer itself
-//
-//export FreeProfiles
-func FreeProfiles(profiles *C.serverProfiles) {
- // We should only free the profiles if we have them (which we should)
- if profiles.total_profiles > 0 {
- // Convert it to a go slice
- profilesSlice := unsafe.Slice(profiles.profiles, profiles.total_profiles)
- // Loop through the pointers and free th allocated strings and the struct itself
- for i := C.size_t(0); i < profiles.total_profiles; i++ {
- C.free(unsafe.Pointer(profilesSlice[i].id))
- C.free(unsafe.Pointer(profilesSlice[i].display_name))
- C.free(unsafe.Pointer(profilesSlice[i]))
- }
- // Free the inner profiles struct
- C.free(unsafe.Pointer(profiles.profiles))
- }
- // Free the profiles struct itself
- C.free(unsafe.Pointer(profiles))
-}
-
-// Get a list of strings with a size as a c structure
-// Returns the size in size_t and the list of strings as a double pointer char
-func getCPtrListStrings(allStrings []string) (C.size_t, **C.char) {
- // Get the total strings in size_t
- totalStrings := C.size_t(len(allStrings))
-
- // If we have strings
- // Allocate memory for the strings array
- if totalStrings > 0 {
- stringsPtr := (**C.char)(C.malloc(totalStrings * C.size_t(unsafe.Sizeof(uintptr(0)))))
- // Go slice conversion
- cStrings := unsafe.Slice(stringsPtr, totalStrings)
-
- // Loop through and allocate the string for each contact
- for index, string := range allStrings {
- cStrings[index] = C.CString(string)
- }
- return totalStrings, (**C.char)(stringsPtr)
- }
-
- // No strings then the length is zero and the char array is nil
- return C.size_t(0), nil
-}
-
-// Function for freeing an array/list of strings
-// It takes the strings as a pointer to a string and the total strings in size_t
-func freeCListStrings(allStrings **C.char, totalStrings C.size_t) {
- // If we have strings we should free them
- // By converting to a Go slice, and freeing them ony by one
- // At last free the pointer itself
- if totalStrings > 0 {
- stringsSlice := unsafe.Slice(allStrings, totalStrings)
- for i := C.size_t(0); i < totalStrings; i++ {
- C.free(unsafe.Pointer(stringsSlice[i]))
- }
- C.free(unsafe.Pointer(allStrings))
- }
-}
-
-// Function for getting the server,
-// It gets the main state as a pointer as we need to convert some string maps to localized strings
-// It gets the base information for a server as well
-func getCPtrServer(state *client.Client, base *client.ServerBase) *C.server {
- // Allocation using malloc and the size of the struct
- cServer := (*C.server)(C.malloc(C.size_t(unsafe.Sizeof(C.server{}))))
- // String allocation and translate the display name
- identifier := base.URL
- countryCode := ""
- // A secure internet server has multiple locations
- locations := []string{}
- if base.Type == "secure_internet" {
- identifier = state.Servers.SecureInternetHomeServer.HomeOrganizationID
- countryCode = state.Servers.SecureInternetHomeServer.CurrentLocation
- locations = state.Discovery.SecureLocationList()
- }
-
- cServer.identifier = C.CString(identifier)
- cServer.display_name = C.CString(state.GetTranslated(base.DisplayName))
- cServer.country_code = C.CString(countryCode)
- cServer.server_type = C.CString(base.Type)
- // Call the helper to get the list of support contacts
- cServer.total_support_contact, cServer.support_contact = getCPtrListStrings(
- base.SupportContact,
- )
- locationsStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{}))))
- locationsStruct.total_locations, locationsStruct.locations = getCPtrListStrings(locations)
- cServer.locations = locationsStruct
-
- profiles := base.ValidProfiles(state.SupportsWireguard)
- cServer.profiles = getCPtrProfiles(&profiles)
- // No endtime is given if we get servers when it has been partially initialised
- if base.EndTime.IsZero() {
- cServer.expire_time = C.ulonglong(0)
- } else {
- // The expire time should be stored as an unsigned long long in unix time
- cServer.expire_time = C.ulonglong(base.EndTime.Unix())
- }
- return cServer
-}
-
-// Function for freeing a single server
-// Gets the pointer to C struct
-//
-//export FreeServer
-func FreeServer(info *C.server) {
- // Free strings
- C.free(unsafe.Pointer(info.identifier))
- C.free(unsafe.Pointer(info.display_name))
- C.free(unsafe.Pointer(info.country_code))
- C.free(unsafe.Pointer(info.server_type))
-
- // Free arrays
- freeCListStrings(info.support_contact, info.total_support_contact)
- FreeSecureLocations(info.locations)
- FreeProfiles(info.profiles)
-
- // Free the struct itself
- C.free(unsafe.Pointer(info))
-}
-
-// Get the C ptr to the servers, returns the length in size_t and the double pointer to the struct
-func getCPtrServers(
- state *client.Client,
- serverMap map[string]*server.InstituteAccessServer,
-) (C.size_t, **C.server) {
- totalServers := C.size_t(len(serverMap))
- // If we have servers, which is not always the case
- if totalServers > 0 {
- serversPtr := (**C.server)(C.malloc(totalServers * C.size_t(unsafe.Sizeof(uintptr(0)))))
- servers := unsafe.Slice(serversPtr, totalServers)
- index := 0
- for _, currentServer := range serverMap {
- cServer := getCPtrServer(state, &currentServer.Basic)
- servers[index] = cServer
- index++
- }
- return totalServers, serversPtr
- }
- return C.size_t(0), nil
-}
-
-// This function takes the servers as a C struct pointer as input
-// It frees all allocated memory for the server
-//
-//export FreeServers
-func FreeServers(cServers *C.servers) {
- // Free the custom servers if there are any
- if cServers.total_custom > 0 {
- customServers := unsafe.Slice(cServers.custom_servers, cServers.total_custom)
- for i := C.size_t(0); i < cServers.total_custom; i++ {
- FreeServer(customServers[i])
- }
- C.free(unsafe.Pointer(cServers.custom_servers))
- }
- // Free the institute access servers if there are any
- if cServers.total_institute > 0 {
- instituteServers := unsafe.Slice(cServers.institute_servers, cServers.total_institute)
-
- for i := C.size_t(0); i < cServers.total_institute; i++ {
- FreeServer(instituteServers[i])
- }
- C.free(unsafe.Pointer(cServers.institute_servers))
- }
- // Free the secure internet server if there is one
- if cServers.secure_internet_server != nil {
- FreeServer(cServers.secure_internet_server)
- }
- // Free the structure itself
- C.free(unsafe.Pointer(cServers))
-}
-
-// Return the servers as a C struct pointer
-// It takes the state as a pointer as we need to translate some strings
-// It also takes the servers as a pointer that belongs to the main state or gathered from the callback
-func getSavedServersWithOptions(state *client.Client, servers *server.Servers) *C.servers {
- // Allocate the struct that we will return
- // With the size of the c struct
- returnedStruct := (*C.servers)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{}))))
-
- // Get the different categories of servers
- totalCustom, customPtr := getCPtrServers(state, servers.CustomServers.Map)
- totalInstitute, institutePtr := getCPtrServers(state, servers.InstituteServers.Map)
- var secureServerPtr *C.server
- secureInternetBase, secureInternetBaseErr := servers.SecureInternetHomeServer.Base()
- if secureInternetBaseErr == nil && secureInternetBase != nil {
- // FIXME: log error?
- secureServerPtr = getCPtrServer(state, secureInternetBase)
- // Give a new identifier
- C.free(unsafe.Pointer(secureServerPtr.identifier))
- secureServerPtr.identifier = C.CString(servers.SecureInternetHomeServer.HomeOrganizationID)
- secureServerPtr.country_code = C.CString(servers.SecureInternetHomeServer.CurrentLocation)
- }
-
- // Fill the struct and return
- returnedStruct.custom_servers = customPtr
- returnedStruct.total_custom = totalCustom
- returnedStruct.institute_servers = institutePtr
- returnedStruct.total_institute = totalInstitute
- returnedStruct.secure_internet_server = secureServerPtr
- return returnedStruct
-}
-
-// This function takes the name as input which is the name of the client
-// It gets the state by name and then returns the saved servers as a c struct belonging to it
-//
-//export GetSavedServers
-func GetSavedServers(name *C.char) (*C.servers, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, getError(stateErr)
- }
- servers := getSavedServersWithOptions(state, &state.Servers)
- return servers, nil
-}
-
-// This function takes the name as input which is the name of the client
-// It gets the state by name and then returns the current server as a c struct belonging to it
-//
-//export GetCurrentServer
-func GetCurrentServer(name *C.char) (*C.server, *C.error) {
- nameStr := C.GoString(name)
- state, stateErr := GetVPNState(nameStr)
- if stateErr != nil {
- return nil, getError(stateErr)
- }
- server, serverErr := state.Servers.GetCurrentServer()
- if serverErr != nil {
- return nil, getError(serverErr)
- }
- base, baseErr := server.Base()
- if baseErr != nil {
- return nil, getError(baseErr)
- }
- cServer := getCPtrServer(state, base)
- return cServer, nil
-}
-
-// This function takes the state as input which is the main state
-// It also takes the data as an interface and if it has the servers type gets the data as a c struct otherwise nil
-func getTransitionDataServers(state *client.Client, data interface{}) *C.servers {
- if converted, ok := data.(server.Servers); ok {
- return getSavedServersWithOptions(state, &converted)
- }
- return nil
-}
-
-//export FreeSecureLocations
-func FreeSecureLocations(locations *C.serverLocations) {
- freeCListStrings(locations.locations, locations.total_locations)
- C.free(unsafe.Pointer(locations))
-}
-
-func getTransitionSecureLocations(data interface{}) *C.serverLocations {
- if locations, ok := data.([]string); ok {
- returnedStruct := (*C.serverLocations)(C.malloc(C.size_t(unsafe.Sizeof(C.servers{}))))
- returnedStruct.total_locations, returnedStruct.locations = getCPtrListStrings(locations)
- return returnedStruct
- }
- return nil
-}
-
-func getTransitionProfiles(data interface{}) *C.serverProfiles {
- if profiles, ok := data.(*server.ProfileInfo); ok {
- return getCPtrProfiles(profiles)
- }
- return nil
-}
-
-func getTransitionServer(state *client.Client, data interface{}) *C.server {
- if server, ok := data.(server.Server); ok {
- base, baseErr := server.Base()
- if baseErr != nil {
- // TODO: LOG
- return nil
- }
- return getCPtrServer(state, base)
- }
- return nil
-}