diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-03-20 13:02:18 +0100 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | 7e7472c30e09eed15424494547729c1f93bc924e (patch) | |
| tree | 92ed648c034fd5618c07b51a4ce66eefcad53e54 /exports/exports.go | |
| parent | 56d7f4d6fc7f4e3a0e31d38f9b73de0375e90349 (diff) | |
Exports: Implement initial V2 API
The main change is that we now use JSON from types listed at the
`types` package
Diffstat (limited to 'exports/exports.go')
| -rw-r--r-- | exports/exports.go | 643 |
1 files changed, 236 insertions, 407 deletions
diff --git a/exports/exports.go b/exports/exports.go index c09e980..d8fd6ea 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -6,114 +6,87 @@ package main #include "server.h" typedef long long int (*ReadRxBytes)(); -typedef struct token { - const char* access; - const char* refresh; - unsigned long long int expired; -} token; -typedef void (*UpdateToken)(const char* name, server* srv, token* tok); - -typedef struct configData { - const char* config; - const char* config_type; - token* tokens; -} configData; - -typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data); +typedef int (*StateCB)(int oldstate, int newstate, void* data); static long long int get_read_rx_bytes(ReadRxBytes read) { return read(); } - -static void update_token(UpdateToken func, const char* name, server* srv, token* tok) -{ - func(name, srv, tok); -} - -static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data) +static int call_callback(StateCB callback, int oldstate, int newstate, void* data) { - return callback(name, oldstate, newstate, data); + return callback(oldstate, newstate, data); } */ import "C" import ( - "time" + "encoding/json" "unsafe" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/server" "github.com/go-errors/errors" "github.com/eduvpn/eduvpn-common/client" + "github.com/eduvpn/eduvpn-common/types" ) -var PStateCallbacks map[string]C.PythonCB +var ( + PStateCallback C.StateCB + VPNState *client.Client +) -var VPNStates map[string]*client.Client +func getTokens(tokens *C.char) (t types.Tokens, err error) { + err = json.Unmarshal([]byte(C.GoString(tokens)), &t) + return t, err +} -func GetStateData( - state *client.Client, - stateID client.FSMStateID, - data interface{}, -) unsafe.Pointer { - switch stateID { - case client.StateNoServer: - return (unsafe.Pointer)(getTransitionDataServers(state, data)) - case client.StateOAuthStarted: - if converted, ok := data.(string); ok { - return (unsafe.Pointer)(C.CString(converted)) - } - case client.StateAskLocation: - return (unsafe.Pointer)(getTransitionSecureLocations(data)) - case client.StateAskProfile: - return (unsafe.Pointer)(getTransitionProfiles(data)) - case client.StateDisconnected: - return (unsafe.Pointer)(getTransitionServer(state, data)) - case client.StateDisconnecting: - return (unsafe.Pointer)(getTransitionServer(state, data)) - case client.StateConnecting: - return (unsafe.Pointer)(getTransitionServer(state, data)) - case client.StateConnected: - return (unsafe.Pointer)(getTransitionServer(state, data)) - default: - return nil +func getCError(err error) *C.char { + if err == nil { + return C.CString("") } - return nil + return C.CString(err.Error()) +} + +func getReturnData(data interface{}) (string, error) { + // If it is already a string return directly + if x, ok := data.(string); ok { + return x, nil + } + + // Otherwise use JSON + b, err := json.Marshal(data) + if err != nil { + return "", err + } + return string(b), nil } func StateCallback( state *client.Client, - name string, oldState client.FSMStateID, newState client.FSMStateID, data interface{}, ) bool { - PStateCallback, exists := PStateCallbacks[name] - if !exists || PStateCallback == nil { + if PStateCallback == nil { return false } - nameC := C.CString(name) oldStateC := C.int(oldState) newStateC := C.int(newState) - dataC := GetStateData(state, newState, data) - handled := C.call_callback(PStateCallback, nameC, oldStateC, newStateC, dataC) - C.free(unsafe.Pointer(nameC)) - // data_c gets freed by the wrapper + d, err := getReturnData(data) + if err != nil { + return false + } + dataC := C.CString(d) + handled := C.call_callback(PStateCallback, oldStateC, newStateC, unsafe.Pointer(dataC)) + FreeString(dataC) return handled == C.int(1) } -func GetVPNState(name string) (*client.Client, error) { - state, exists := VPNStates[name] - - if !exists || state == nil { - return nil, errors.Errorf("state with name %s not found", name) +func getVPNState() (*client.Client, error) { + if VPNState == nil { + return nil, errors.New("No state available, did you register the client?") } - - return state, nil + return VPNState, nil } //export Register @@ -121,39 +94,44 @@ func Register( name *C.char, version *C.char, configDirectory *C.char, - language *C.char, - stateCallback C.PythonCB, + stateCallback C.StateCB, debug C.int, -) *C.error { - nameStr := C.GoString(name) - versionStr := C.GoString(version) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - state = &client.Client{} - } - if VPNStates == nil { - VPNStates = make(map[string]*client.Client) - } - if PStateCallbacks == nil { - PStateCallbacks = make(map[string]C.PythonCB) - } - VPNStates[nameStr] = state - PStateCallbacks[nameStr] = stateCallback +) *C.char { + state, stateErr := getVPNState() + if stateErr == nil { + return getCError(errors.New("failed to register, a VPN state is already present")) + } + state = &client.Client{} + VPNState = state + PStateCallback = stateCallback registerErr := state.Register( - nameStr, - versionStr, + C.GoString(name), + C.GoString(version), C.GoString(configDirectory), - C.GoString(language), func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool { - return StateCallback(state, nameStr, old, new, data) + return StateCallback(state, old, new, data) }, - debug != 0, + debug == 1, ) - if registerErr != nil { - delete(VPNStates, nameStr) + return getCError(registerErr) +} + +//export ExpiryTimes +func ExpiryTimes() (*C.char, *C.char) { + state, stateErr := getVPNState() + if stateErr != nil { + return nil, getCError(stateErr) + } + exp, err := state.ExpiryTimes() + if err != nil { + return nil, getCError(err) + } + ret, err := getReturnData(exp) + if err != nil { + return nil, getCError(err) } - return getError(registerErr) + return C.CString(ret), nil } //export SetTokenUpdater @@ -179,362 +157,216 @@ func SetTokenUpdater(name *C.char, updater C.UpdateToken) *C.error { } //export Deregister -func Deregister(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func Deregister() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } state.Deregister() + VPNState = nil return nil } -func getError(err error) *C.error { - if err == nil { - return nil - } - errorStruct := (*C.error)( - C.malloc(C.size_t(unsafe.Sizeof(C.error{}))), - ) - if err1, ok := err.(*errors.Error); ok { - if err1 == nil { - errorStruct.traceback = C.CString("N/A") - errorStruct.cause = C.CString("unknown error") - return errorStruct - } - errorStruct.traceback = C.CString(err1.ErrorStack()) - if err1.Err == nil { - errorStruct.cause = C.CString(err1.Error()) - } else { - errorStruct.cause = C.CString(err1.Err.Error()) - } - } else { - errorStruct.traceback = C.CString("N/A") - errorStruct.cause = C.CString(err.Error()) - } - return errorStruct -} - -//export FreeError -func FreeError(err *C.error) { - C.free(unsafe.Pointer(err.traceback)) - C.free(unsafe.Pointer(err.cause)) - C.free(unsafe.Pointer(err)) -} - //export CancelOAuth -func CancelOAuth(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func CancelOAuth() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } cancelErr := state.CancelOAuth() - return getError(cancelErr) + return getCError(cancelErr) } -//export RemoveSecureInternet -func RemoveSecureInternet(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export AddServer +func AddServer(_type *C.char, id *C.char) *C.char { + // TODO: type + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) - } - removeErr := state.RemoveSecureInternet() - return getError(removeErr) -} - -//export AddInstituteAccess -func AddInstituteAccess(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) + } + t := C.GoString(_type) + var err error + switch t { + case "institute_access": + err = state.AddInstituteServer(C.GoString(id)) + case "secure_internet": + err = state.AddSecureInternetHomeServer(C.GoString(id)) + case "custom_server": + err = state.AddCustomServer(C.GoString(id)) + default: + err = errors.Errorf("invalid type: %v", t) } - // FIXME: Return server result - _, addErr := state.AddInstituteServer(C.GoString(url)) - return getError(addErr) + return getCError(err) } -//export AddSecureInternetHomeServer -func AddSecureInternetHomeServer(name *C.char, orgID *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export RemoveServer +func RemoveServer(_type *C.char, id *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) - } - // FIXME: Return server result - _, addErr := state.AddSecureInternetHomeServer(C.GoString(orgID)) - return getError(addErr) -} - -//export AddCustomServer -func AddCustomServer(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) + } + t := C.GoString(_type) + var err error + switch t { + case "institute_access": + err = state.RemoveInstituteAccess(C.GoString(id)) + case "secure_internet": + err = state.RemoveSecureInternet() + case "custom_server": + err = state.RemoveCustomServer(C.GoString(id)) + default: + err = errors.Errorf("invalid type: %v", t) } - // FIXME: Return server result - _, addErr := state.AddCustomServer(C.GoString(url)) - return getError(addErr) + return getCError(err) } -//export RemoveInstituteAccess -func RemoveInstituteAccess(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export CurrentServer +func CurrentServer() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - removeErr := state.RemoveInstituteAccess(C.GoString(url)) - return getError(removeErr) -} - -//export RemoveCustomServer -func RemoveCustomServer(name *C.char, url *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + srv, err := state.CurrentServer() + if err != nil { + return nil, getCError(err) } - removeErr := state.RemoveCustomServer(C.GoString(url)) - return getError(removeErr) -} - -func cToken(t oauth.Token) *C.token { - cTok := (*C.token)(C.malloc(C.size_t(unsafe.Sizeof(C.token{})))) - cTok.access = C.CString(t.Access) - cTok.refresh = C.CString(t.Refresh) - cTok.expired = C.ulonglong(t.ExpiredTimestamp.Unix()) - return cTok -} - -func cConfig(config *client.ConfigData) *C.configData { - // No config so return nil pointer - if config == nil { - return nil + ret, err := getReturnData(srv) + if err != nil { + return nil, getCError(err) } - cConf := (*C.configData)(C.malloc(C.size_t(unsafe.Sizeof(C.configData{})))) - cConf.config = C.CString(config.Config) - cConf.config_type = C.CString(config.Type) - cConf.tokens = cToken(config.Tokens) - return cConf + return C.CString(ret), nil } -//export FreeTokens -func FreeTokens(tokens *C.token) { - C.free(unsafe.Pointer(tokens.access)) - C.free(unsafe.Pointer(tokens.refresh)) - C.free(unsafe.Pointer(tokens)) -} - -//export FreeConfig -func FreeConfig(config *C.configData) { - C.free(unsafe.Pointer(config.config)) - C.free(unsafe.Pointer(config.config_type)) - C.free(unsafe.Pointer(config.tokens.access)) - C.free(unsafe.Pointer(config.tokens.refresh)) - C.free(unsafe.Pointer(config.tokens)) - C.free(unsafe.Pointer(config)) -} - -//export GetConfigSecureInternet -func GetConfigSecureInternet( - name *C.char, - orgID *C.char, - preferTCP C.int, - prevTokens C.token, -) (*C.configData, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export ServerList +func ServerList() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return nil, getError(stateErr) - } - preferTCPBool := preferTCP == 1 - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + return nil, getCError(stateErr) } - cfg, configErr := state.GetConfigSecureInternet(C.GoString(orgID), preferTCPBool, t) - return cConfig(cfg), getError(configErr) -} - -//export GetConfigInstituteAccess -func GetConfigInstituteAccess( - name *C.char, - url *C.char, - preferTCP C.int, - prevTokens C.token, -) (*C.configData, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return nil, getError(stateErr) + list, err := state.ServerList() + if err != nil { + return nil, getCError(err) } - preferTCPBool := preferTCP == 1 - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + ret, err := getReturnData(list) + if err != nil { + return nil, getCError(err) } - cfg, configErr := state.GetConfigInstituteAccess(C.GoString(url), preferTCPBool, t) - return cConfig(cfg), getError(configErr) + return C.CString(ret), nil } -//export GetConfigCustomServer -func GetConfigCustomServer( - name *C.char, - url *C.char, - preferTCP C.int, - prevTokens C.token, -) (*C.configData, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export GetConfig +func GetConfig(_type *C.char, id *C.char, pTCP C.int, tokens *C.char) (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return nil, getError(stateErr) + return nil, getCError(stateErr) + } + preferTCPBool := pTCP == 1 + tok, err := getTokens(tokens) + if err != nil { + return nil, getCError(err) + } + t := C.GoString(_type) + var cfg *types.Configuration + switch t { + case "institute_access": + cfg, err = state.GetConfigInstituteAccess(C.GoString(id), preferTCPBool, tok) + case "secure_internet": + cfg, err = state.GetConfigSecureInternet(C.GoString(id), preferTCPBool, tok) + case "custom_server": + cfg, err = state.GetConfigCustomServer(C.GoString(id), preferTCPBool, tok) + default: + err = errors.Errorf("invalid type: %v", t) } - preferTCPBool := preferTCP == 1 - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + if cfg != nil && err == nil { + d, err := getReturnData(cfg) + if err == nil { + return C.CString(d), nil + } } - cfg, configErr := state.GetConfigCustomServer(C.GoString(url), preferTCPBool, t) - return cConfig(cfg), getError(configErr) + return nil, getCError(err) } //export SetProfileID -func SetProfileID(name *C.char, data *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func SetProfileID(data *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } profileErr := state.SetProfileID(C.GoString(data)) - return getError(profileErr) -} - -//export ChangeSecureLocation -func ChangeSecureLocation(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) - } - locationErr := state.ChangeSecureLocation() - return getError(locationErr) + return getCError(profileErr) } //export SetSecureLocation -func SetSecureLocation(name *C.char, data *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func SetSecureLocation(data *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } locationErr := state.SetSecureLocation(C.GoString(data)) - return getError(locationErr) -} - -//export GoBack -func GoBack(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) - } - goBackErr := state.GoBack() - return getError(goBackErr) + return getCError(locationErr) } -//export SetSearchServer -func SetSearchServer(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export DiscoServers +func DiscoServers() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - setSearchErr := state.SetSearchServer() - return getError(setSearchErr) -} - -//export Cleanup -func Cleanup(name *C.char, prevTokens C.token) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + servers, err := state.DiscoServers() + if servers == nil && err != nil { + return nil, getCError(err) } - t := oauth.Token{ - Access: C.GoString(prevTokens.access), - Refresh: C.GoString(prevTokens.refresh), - ExpiredTimestamp: time.Unix(int64(prevTokens.expired), 0), + s, reterr := getReturnData(servers) + if reterr != nil { + return nil, getCError(reterr) } - err := state.Cleanup(t) - return getError(err) + return C.CString(s), getCError(err) } -//export SetDisconnected -func SetDisconnected(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export DiscoOrganizations +func DiscoOrganizations() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - setDisconnectedErr := state.SetDisconnected() - return getError(setDisconnectedErr) -} - -//export SetDisconnecting -func SetDisconnecting(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + orgs, err := state.DiscoOrganizations() + if orgs == nil && err != nil { + return nil, getCError(err) } - setDisconnectingErr := state.SetDisconnecting() - return getError(setDisconnectingErr) -} - -//export SetConnecting -func SetConnecting(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) - if stateErr != nil { - return getError(stateErr) + s, reterr := getReturnData(orgs) + if reterr != nil { + return nil, getCError(reterr) } - setConnectingErr := state.SetConnecting() - return getError(setConnectingErr) + return C.CString(s), getCError(err) } -//export SetConnected -func SetConnected(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export Cleanup +func Cleanup(prevTokens *C.char) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } - setConnectedErr := state.SetConnected() - return getError(setConnectedErr) + t, err := getTokens(prevTokens) + if err != nil { + return getCError(err) + } + err = state.Cleanup(t) + return getCError(err) } //export RenewSession -func RenewSession(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func RenewSession() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } renewSessionErr := state.RenewSession() - return getError(renewSessionErr) + return getCError(renewSessionErr) } //export ShouldRenewButton -func ShouldRenewButton(name *C.char) C.int { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func ShouldRenewButton() C.int { + state, stateErr := getVPNState() if stateErr != nil { return C.int(0) } @@ -545,47 +377,45 @@ func ShouldRenewButton(name *C.char) C.int { return C.int(0) } -//export InFSMState -func InFSMState(name *C.char, checkState C.int) C.int { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export SetSupportWireguard +func SetSupportWireguard(support C.int) *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return C.int(0) + return getCError(stateErr) } - inStateBool := state.InFSMState(client.FSMStateID(checkState)) - if inStateBool { - return C.int(1) - } - return C.int(0) + state.SupportsWireguard = support == 1 + return nil } -//export SetSupportWireguard -func SetSupportWireguard(name *C.char, support C.int) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +//export SecureLocationList +func SecureLocationList() (*C.char, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return nil, getCError(stateErr) } - state.SupportsWireguard = support == 1 - return nil + locs := state.Discovery.SecureLocationList() + l, err := getReturnData(locs) + if err != nil { + return nil, getCError(err) + } + return C.CString(l), nil } //export StartFailover -func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.error) { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) { + state, stateErr := getVPNState() if stateErr != nil { - return C.int(0), getError(stateErr) + return C.int(0), getCError(stateErr) } dropped, droppedErr := state.StartFailover(C.GoString(gateway), int(mtu), func() (int64, error) { rxBytes := int64(C.get_read_rx_bytes(readRxBytes)) - if rxBytes == -1 { + if rxBytes < 0 { return 0, errors.New("client gave an invalid rx bytes value") } return rxBytes, nil }) if droppedErr != nil { - return C.int(0), getError(droppedErr) + return C.int(0), getCError(droppedErr) } droppedC := C.int(0) if dropped { @@ -595,15 +425,14 @@ func StartFailover(name *C.char, gateway *C.char, mtu C.int, readRxBytes C.ReadR } //export CancelFailover -func CancelFailover(name *C.char) *C.error { - nameStr := C.GoString(name) - state, stateErr := GetVPNState(nameStr) +func CancelFailover() *C.char { + state, stateErr := getVPNState() if stateErr != nil { - return getError(stateErr) + return getCError(stateErr) } cancelErr := state.CancelFailover() if cancelErr != nil { - return getError(cancelErr) + return getCError(cancelErr) } return nil } |
