summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--exports/exports.go196
1 files changed, 103 insertions, 93 deletions
diff --git a/exports/exports.go b/exports/exports.go
index b0c13e0..af56a8a 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -1,6 +1,7 @@
package main
/*
+#include <stdint.h>
#include <stdlib.h>
#include "error.h"
#include "server.h"
@@ -21,22 +22,20 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat
import "C"
import (
+ "context"
"encoding/json"
+ "runtime/cgo"
"unsafe"
"github.com/go-errors/errors"
"github.com/eduvpn/eduvpn-common/client"
+ "github.com/eduvpn/eduvpn-common/types/cookie"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
)
var VPNState *client.Client
-func getTokens(tokens *C.char) (t srvtypes.Tokens, err error) {
- err = json.Unmarshal([]byte(C.GoString(tokens)), &t)
- return t, err
-}
-
func getCError(err error) *C.char {
if err == nil {
return nil
@@ -174,35 +173,18 @@ func Deregister() *C.char {
return nil
}
-//export CancelOAuth
-func CancelOAuth() *C.char {
- state, stateErr := getVPNState()
- if stateErr != nil {
- return getCError(stateErr)
- }
- cancelErr := state.CancelOAuth()
- return getCError(cancelErr)
-}
-
//export AddServer
-func AddServer(_type C.int, id *C.char) *C.char {
+func AddServer(c C.uintptr_t, _type C.int, id *C.char, ni C.int) *C.char {
// TODO: type
state, stateErr := getVPNState()
if stateErr != nil {
return getCError(stateErr)
}
- t := int(_type)
- var err error
- switch t {
- case int(srvtypes.TypeInstituteAccess):
- err = state.AddInstituteServer(C.GoString(id))
- case int(srvtypes.TypeSecureInternet):
- err = state.AddSecureInternetHomeServer(C.GoString(id))
- case int(srvtypes.TypeCustom):
- err = state.AddCustomServer(C.GoString(id))
- default:
- err = errors.Errorf("invalid type: %v", t)
+ v, err := getCookie(c)
+ if err != nil {
+ return getCError(err)
}
+ err = state.AddServer(v, C.GoString(id), srvtypes.Type(_type), ni != 0)
return getCError(err)
}
@@ -212,18 +194,7 @@ func RemoveServer(_type C.int, id *C.char) *C.char {
if stateErr != nil {
return getCError(stateErr)
}
- t := int(_type)
- var err error
- switch t {
- case int(srvtypes.TypeInstituteAccess):
- err = state.RemoveInstituteAccess(C.GoString(id))
- case int(srvtypes.TypeSecureInternet):
- err = state.RemoveSecureInternet()
- case int(srvtypes.TypeCustom):
- err = state.RemoveCustomServer(C.GoString(id))
- default:
- err = errors.Errorf("invalid type: %v", t)
- }
+ err := state.RemoveServer(C.GoString(id), srvtypes.Type(_type))
return getCError(err)
}
@@ -262,28 +233,17 @@ func ServerList() (*C.char, *C.char) {
}
//export GetConfig
-func GetConfig(_type C.int, id *C.char, pTCP C.int, tokens *C.char) (*C.char, *C.char) {
+func GetConfig(c C.uintptr_t, _type C.int, id *C.char, pTCP C.int) (*C.char, *C.char) {
state, stateErr := getVPNState()
if stateErr != nil {
return nil, getCError(stateErr)
}
- preferTCPBool := pTCP != 0
- tok, err := getTokens(tokens)
+ ck, err := getCookie(c)
if err != nil {
return nil, getCError(err)
}
- t := int(_type)
- var cfg *srvtypes.Configuration
- switch t {
- case int(srvtypes.TypeInstituteAccess):
- cfg, err = state.GetConfigInstituteAccess(C.GoString(id), preferTCPBool, tok)
- case int(srvtypes.TypeSecureInternet):
- cfg, err = state.GetConfigSecureInternet(C.GoString(id), preferTCPBool, tok)
- case int(srvtypes.TypeCustom):
- cfg, err = state.GetConfigCustomServer(C.GoString(id), preferTCPBool, tok)
- default:
- err = errors.Errorf("invalid type: %v", t)
- }
+ preferTCPBool := pTCP != 0
+ cfg, err := state.GetConfig(ck, C.GoString(id), srvtypes.Type(_type), preferTCPBool)
if cfg != nil && err == nil {
d, err := getReturnData(cfg)
if err == nil {
@@ -304,22 +264,30 @@ func SetProfileID(data *C.char) *C.char {
}
//export SetSecureLocation
-func SetSecureLocation(data *C.char) *C.char {
+func SetSecureLocation(c C.uintptr_t, data *C.char) *C.char {
state, stateErr := getVPNState()
if stateErr != nil {
return getCError(stateErr)
}
- locationErr := state.SetSecureLocation(C.GoString(data))
+ ck, err := getCookie(c)
+ if err != nil {
+ return getCError(err)
+ }
+ locationErr := state.SetSecureLocation(ck, C.GoString(data))
return getCError(locationErr)
}
//export DiscoServers
-func DiscoServers() (*C.char, *C.char) {
+func DiscoServers(c C.uintptr_t) (*C.char, *C.char) {
state, stateErr := getVPNState()
if stateErr != nil {
return nil, getCError(stateErr)
}
- servers, err := state.DiscoServers()
+ ck, err := getCookie(c)
+ if err != nil {
+ return nil, getCError(err)
+ }
+ servers, err := state.DiscoServers(ck)
if servers == nil && err != nil {
return nil, getCError(err)
}
@@ -331,12 +299,16 @@ func DiscoServers() (*C.char, *C.char) {
}
//export DiscoOrganizations
-func DiscoOrganizations() (*C.char, *C.char) {
+func DiscoOrganizations(c C.uintptr_t) (*C.char, *C.char) {
state, stateErr := getVPNState()
if stateErr != nil {
return nil, getCError(stateErr)
}
- orgs, err := state.DiscoOrganizations()
+ ck, err := getCookie(c)
+ if err != nil {
+ return nil, getCError(err)
+ }
+ orgs, err := state.DiscoOrganizations(ck)
if orgs == nil && err != nil {
return nil, getCError(err)
}
@@ -348,26 +320,30 @@ func DiscoOrganizations() (*C.char, *C.char) {
}
//export Cleanup
-func Cleanup(prevTokens *C.char) *C.char {
+func Cleanup(c C.uintptr_t) *C.char {
state, stateErr := getVPNState()
if stateErr != nil {
return getCError(stateErr)
}
- t, err := getTokens(prevTokens)
+ ck, err := getCookie(c)
if err != nil {
return getCError(err)
}
- err = state.Cleanup(t)
+ err = state.Cleanup(ck)
return getCError(err)
}
//export RenewSession
-func RenewSession() *C.char {
+func RenewSession(c C.uintptr_t) *C.char {
state, stateErr := getVPNState()
if stateErr != nil {
return getCError(stateErr)
}
- renewSessionErr := state.RenewSession()
+ ck, err := getCookie(c)
+ if err != nil {
+ return getCError(err)
+ }
+ renewSessionErr := state.RenewSession(ck)
return getCError(renewSessionErr)
}
@@ -381,27 +357,17 @@ func SetSupportWireguard(support C.int) *C.char {
return nil
}
-//export SecureLocationList
-func SecureLocationList() (*C.char, *C.char) {
- state, stateErr := getVPNState()
- if stateErr != nil {
- return nil, getCError(stateErr)
- }
- locs := state.Discovery.SecureLocationList()
- l, err := getReturnData(locs)
- if err != nil {
- return nil, getCError(err)
- }
- return C.CString(l), nil
-}
-
//export StartFailover
-func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) {
+func StartFailover(c C.uintptr_t, gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int, *C.char) {
state, stateErr := getVPNState()
if stateErr != nil {
return C.int(0), getCError(stateErr)
}
- dropped, droppedErr := state.StartFailover(C.GoString(gateway), int(mtu), func() (int64, error) {
+ ck, err := getCookie(c)
+ if err != nil {
+ return C.int(0), getCError(err)
+ }
+ dropped, droppedErr := state.StartFailover(ck, C.GoString(gateway), int(mtu), func() (int64, error) {
rxBytes := int64(C.get_read_rx_bytes(readRxBytes))
if rxBytes < 0 {
return 0, errors.New("client gave an invalid rx bytes value")
@@ -418,22 +384,66 @@ func StartFailover(gateway *C.char, mtu C.int, readRxBytes C.ReadRxBytes) (C.int
return droppedC, nil
}
-//export CancelFailover
-func CancelFailover() *C.char {
- state, stateErr := getVPNState()
- if stateErr != nil {
- return getCError(stateErr)
+//export FreeString
+func FreeString(addr *C.char) {
+ C.free(unsafe.Pointer(addr))
+}
+
+func getCookie(c C.uintptr_t) (*cookie.Cookie, error) {
+ if c == 0 {
+ return nil, errors.New("cookie is nil")
}
- cancelErr := state.CancelFailover()
- if cancelErr != nil {
- return getCError(cancelErr)
+ h := cgo.Handle(c)
+ v, ok := h.Value().(*cookie.Cookie)
+ if !ok {
+ return nil, errors.New("value is not a cookie")
}
- return nil
+ // the cookie itself has a reference to the handle
+ // such that we can return the same exact handle in callbacks
+ // TODO: On first glance this might not make any sense, find a better way
+ v.H = h
+ return v, nil
}
-//export FreeString
-func FreeString(addr *C.char) {
- C.free(unsafe.Pointer(addr))
+//export CookieNew
+func CookieNew() C.uintptr_t {
+ c := cookie.NewWithContext(context.Background())
+ return C.uintptr_t(cgo.NewHandle(&c))
+}
+
+//export CookieReply
+func CookieReply(c C.uintptr_t, data *C.char) *C.char {
+ v, err := getCookie(c)
+ if err != nil {
+ return getCError(err)
+ }
+ err = v.Send(C.GoString(data))
+ return getCError(err)
+}
+
+//export CookieDelete
+func CookieDelete(c C.uintptr_t) *C.char {
+ v, err := getCookie(c)
+ if err != nil {
+ return getCError(err)
+ }
+ // cancel the cookie and then delete the handle
+ err = v.Cancel()
+ cgo.Handle(c).Delete()
+ return getCError(err)
+}
+
+//export CookieCancel
+func CookieCancel(c C.uintptr_t) *C.char {
+ v, err := getCookie(c)
+ if err != nil {
+ return getCError(err)
+ }
+ err = v.Cancel()
+ if err != nil {
+ return getCError(err)
+ }
+ return nil
}
// Not used in library, but needed to compile.