diff options
| -rw-r--r-- | exports/exports.go | 196 |
1 files changed, 103 insertions, 93 deletions
diff --git a/exports/exports.go b/exports/exports.go index b0c13e0..af56a8a 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -1,6 +1,7 @@ package main /* +#include <stdint.h> #include <stdlib.h> #include "error.h" #include "server.h" @@ -21,22 +22,20 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat import "C" import ( + "context" "encoding/json" + "runtime/cgo" "unsafe" "github.com/go-errors/errors" "github.com/eduvpn/eduvpn-common/client" + "github.com/eduvpn/eduvpn-common/types/cookie" srvtypes "github.com/eduvpn/eduvpn-common/types/server" ) var VPNState *client.Client -func getTokens(tokens *C.char) (t srvtypes.Tokens, err error) { - err = json.Unmarshal([]byte(C.GoString(tokens)), &t) - return t, err -} - func getCError(err error) *C.char { if err == nil { return nil @@ -174,35 +173,18 @@ func Deregister() *C.char { return nil } -//export CancelOAuth -func CancelOAuth() *C.char { - state, stateErr := getVPNState() - if stateErr != nil { - return getCError(stateErr) - } - cancelErr := state.CancelOAuth() - return getCError(cancelErr) -} - //export AddServer -func AddServer(_type C.int, id *C.char) *C.char { +func AddServer(c C.uintptr_t, _type C.int, id *C.char, ni C.int) *C.char { // TODO: type state, stateErr := getVPNState() if stateErr != nil { return getCError(stateErr) } - t := int(_type) - var err error - switch t { - case int(srvtypes.TypeInstituteAccess): - err = state.AddInstituteServer(C.GoString(id)) - case int(srvtypes.TypeSecureInternet): - err = state.AddSecureInternetHomeServer(C.GoString(id)) - case int(srvtypes.TypeCustom): - err = state.AddCustomServer(C.GoString(id)) - default: - err = errors.Errorf("invalid type: %v", t) + v, err := getCookie(c) + if err != nil { + return getCError(err) } + err = state.AddServer(v, C.GoString(id), srvtypes.Type(_type), ni != 0) return getCError(err) } @@ -212,18 +194,7 @@ func RemoveServer(_type C.int, id *C.char) *C.char { if stateErr != nil { return getCError(stateErr) } - t := int(_type) - var err error - switch t { - case int(srvtypes.TypeInstituteAccess): - err = state.RemoveInstituteAccess(C.GoString(id)) - case int(srvtypes.TypeSecureInternet): - err = state.RemoveSecureInternet() - case int(srvtypes.TypeCustom): - err = state.RemoveCustomServer(C.GoString(id)) - default: - err = errors.Errorf("invalid type: %v", t) - } + err := state.RemoveServer(C.GoString(id), srvtypes.Type(_type)) return getCError(err) } @@ -262,28 +233,17 @@ func ServerList() (*C.char, *C.char) { } //export GetConfig -func GetConfig(_type C.int, id *C.char, pTCP C.int, tokens *C.char) (*C.char, *C.char) { +func GetConfig(c C.uintptr_t, _type C.int, id *C.char, pTCP C.int) (*C.char, *C.char) { state, stateErr := getVPNState() if stateErr != nil { return nil, getCError(stateErr) } - preferTCPBool := pTCP != 0 - tok, err := getTokens(tokens) + ck, err := getCookie(c) if err != nil { return nil, getCError(err) } - t := int(_type) - var cfg *srvtypes.Configuration - switch t { - case int(srvtypes.TypeInstituteAccess): - cfg, err = state.GetConfigInstituteAccess(C.GoString(id), preferTCPBool, tok) - case int(srvtypes.TypeSecureInternet): - cfg, err = state.GetConfigSecureInternet(C.GoString(id), preferTCPBool, tok) - case int(srvtypes.TypeCustom): - cfg, err = state.GetConfigCustomServer(C.GoString(id), preferTCPBool, tok) - default: - err = errors.Errorf("invalid type: %v", t) - } + preferTCPBool := pTCP != 0 + cfg, err := state.GetConfig(ck, C.GoString(id), srvtypes.Type(_type), preferTCPBool) if cfg != nil && err == nil { d, err := getReturnData(cfg) if err == nil { @@ -304,22 +264,30 @@ func SetProfileID(data *C.char) *C.char { } //export SetSecureLocation -func SetSecureLocation(data *C.char) *C.char { +func SetSecureLocation(c C.uintptr_t, data *C.char) *C.char { state, stateErr := getVPNState() if stateErr != nil { return getCError(stateErr) } - locationErr := state.SetSecureLocation(C.GoString(data)) + ck, err := getCookie(c) + if err != nil { + return getCError(err) + } + locationErr := state.SetSecureLocation(ck, C.GoString(data)) return getCError(locationErr) } //export DiscoServers -func DiscoServers() (*C.char, *C.char) { +func DiscoServers(c C.uintptr_t) (*C.char, *C.char) { state, stateErr := getVPNState() if stateErr != nil { return nil, getCError(stateErr) } - servers, err := state.DiscoServers() + ck, err := getCookie(c) + if err != nil { + return nil, getCError(err) + } + servers, err := state.DiscoServers(ck) if servers == nil && err != nil { return nil, getCError(err) } @@ -331,12 +299,16 @@ func DiscoServers() (*C.char, *C.char) { } //export DiscoOrganizations -func DiscoOrganizations() (*C.char, *C.char) { +func DiscoOrganizations(c C.uintptr_t) (*C.char, *C.char) { state, stateErr := getVPNState() if stateErr != nil { return nil, getCError(stateErr) } - orgs, err := state.DiscoOrganizations() + ck, err := getCookie(c) + if err != nil { + return nil, getCError(err) + } + orgs, err := state.DiscoOrganizations(ck) if orgs == nil && err != nil { return nil, getCError(err) } @@ -348,26 +320,30 @@ func DiscoOrganizations() (*C.char, *C.char) { } //export Cleanup -func Cleanup(prevTokens *C.char) *C.char { +func Cleanup(c C.uintptr_t) *C.char { state, stateErr := getVPNState() if stateErr != nil { return getCError(stateErr) } - t, err := getTokens(prevTokens) + ck, err := getCookie(c) if err != nil { return getCError(err) } - err = state.Cleanup(t) + err = state.Cleanup(ck) return getCError(err) } //export RenewSession -func RenewSession() *C.char { +func RenewSession(c C.uintptr_t) *C.char { state, stateErr := getVPNState() if stateErr != nil { return getCError(stateErr) } - renewSessionErr := state.RenewSession() + ck, err := getCookie(c) + if err != nil { + return getCError(err) + } + renewSessionErr := state.RenewSession(ck) return getCError(renewSessionErr) } @@ -381,27 +357,17 @@ func SetSupportWireguard(support C.int) *C.char { return nil } -//export SecureLocationList -func SecureLocationList() (*C.char, *C.char) { - state, stateErr := getVPNState() - if stateErr != nil { - return nil, getCError(stateErr) - } - locs := state.Discovery.SecureLocationList() - l, err := getReturnData(locs) - if err != nil { - return nil, getCError(err) - } - return C.CString(l), nil -} - //export StartFailover -func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) { +func StartFailover(c C.uintptr_t, gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) { state, stateErr := getVPNState() if stateErr != nil { return C.int(0), getCError(stateErr) } - dropped, droppedErr := state.StartFailover(C.GoString(gateway), int(mtu), func() (int64, error) { + ck, err := getCookie(c) + if err != nil { + return C.int(0), getCError(err) + } + dropped, droppedErr := state.StartFailover(ck, C.GoString(gateway), int(mtu), func() (int64, error) { rxBytes := int64(C.get_read_rx_bytes(readRxBytes)) if rxBytes < 0 { return 0, errors.New("client gave an invalid rx bytes value") @@ -418,22 +384,66 @@ func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int return droppedC, nil } -//export CancelFailover -func CancelFailover() *C.char { - state, stateErr := getVPNState() - if stateErr != nil { - return getCError(stateErr) +//export FreeString +func FreeString(addr *C.char) { + C.free(unsafe.Pointer(addr)) +} + +func getCookie(c C.uintptr_t) (*cookie.Cookie, error) { + if c == 0 { + return nil, errors.New("cookie is nil") } - cancelErr := state.CancelFailover() - if cancelErr != nil { - return getCError(cancelErr) + h := cgo.Handle(c) + v, ok := h.Value().(*cookie.Cookie) + if !ok { + return nil, errors.New("value is not a cookie") } - return nil + // the cookie itself has a reference to the handle + // such that we can return the same exact handle in callbacks + // TODO: On first glance this might not make any sense, find a better way + v.H = h + return v, nil } -//export FreeString -func FreeString(addr *C.char) { - C.free(unsafe.Pointer(addr)) +//export CookieNew +func CookieNew() C.uintptr_t { + c := cookie.NewWithContext(context.Background()) + return C.uintptr_t(cgo.NewHandle(&c)) +} + +//export CookieReply +func CookieReply(c C.uintptr_t, data *C.char) *C.char { + v, err := getCookie(c) + if err != nil { + return getCError(err) + } + err = v.Send(C.GoString(data)) + return getCError(err) +} + +//export CookieDelete +func CookieDelete(c C.uintptr_t) *C.char { + v, err := getCookie(c) + if err != nil { + return getCError(err) + } + // cancel the cookie and then delete the handle + err = v.Cancel() + cgo.Handle(c).Delete() + return getCError(err) +} + +//export CookieCancel +func CookieCancel(c C.uintptr_t) *C.char { + v, err := getCookie(c) + if err != nil { + return getCError(err) + } + err = v.Cancel() + if err != nil { + return getCError(err) + } + return nil } // Not used in library, but needed to compile. |
