summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client/client.go108
-rw-r--r--exports/exports.go71
-rw-r--r--internal/server/server.go21
-rw-r--r--types/server/server.go3
-rw-r--r--wrappers/python/eduvpn_common/loader.py3
-rw-r--r--wrappers/python/eduvpn_common/main.py56
-rw-r--r--wrappers/python/eduvpn_common/types.py7
7 files changed, 228 insertions, 41 deletions
diff --git a/client/client.go b/client/client.go
index 813f6dc..28080b0 100644
--- a/client/client.go
+++ b/client/client.go
@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"strings"
+ "time"
"github.com/eduvpn/eduvpn-common/internal/config"
"github.com/eduvpn/eduvpn-common/internal/discovery"
@@ -109,7 +110,59 @@ type Client struct {
Debug bool `json:"-"`
// The Failover monitor for the current VPN connection
- Failover *failover.DroppedConMon
+ Failover *failover.DroppedConMon `json:"-"`
+
+ // TokenSetter sets the tokens in the client
+ TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"`
+
+ // TokenGetter gets the tokens from the client
+ TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"`
+}
+
+func (c *Client) updateTokens(srv server.Server) error {
+ if c.TokenGetter == nil {
+ return errors.New("no tokken getter defined")
+ }
+ pSrv, err := c.pubCurrentServer(srv)
+ if err != nil {
+ return err
+ }
+ // shouldn't happen
+ if pSrv == nil {
+ return errors.New("public server is nil when getting tokens")
+ }
+ tokens := c.TokenGetter(*pSrv)
+ if tokens == nil {
+ return errors.New("client returned nil for tokens")
+ }
+
+ server.UpdateTokens(srv, oauth.Token{
+ Access: tokens.Access,
+ Refresh: tokens.Refresh,
+ ExpiredTimestamp: time.Unix(tokens.Expires, 0),
+ })
+
+ return nil
+}
+
+func (c *Client) forwardTokens(srv server.Server) error {
+ if c.TokenSetter == nil {
+ return errors.New("no token setter defined")
+ }
+ pSrv, err := c.pubCurrentServer(srv)
+ if err != nil {
+ return err
+ }
+ if pSrv == nil {
+ return errors.New("public server is nil when updating tokens")
+ }
+ o := srv.OAuth()
+ if o == nil {
+ return errors.New("oauth was nil when forwarding tokens")
+ }
+ t := o.Token()
+ c.TokenSetter(*pSrv, t.Public())
+ return nil
}
// New creates a new client with the following parameters:
@@ -329,7 +382,18 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool)
// oauth
// TODO: This should be ck.Context()
// But needsrelogin needs a rewrite to support this properly
+
+ // first make sure we get the most up to date tokens from the client
+ err := c.updateTokens(srv)
+ if err != nil {
+ log.Logger.Debugf("failed to get tokens from client: %v", err)
+ }
if server.NeedsRelogin(context.Background(), srv) || forceauth {
+ // mark organizations as expired if the server is a secure internet server
+ b, berr := srv.Base()
+ if berr == nil && b.Type == srvtypes.TypeSecureInternet {
+ c.Discovery.MarkOrganizationsExpired()
+ }
err := c.loginCallback(ck, srv)
if err != nil {
return err
@@ -421,6 +485,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
case srvtypes.TypeSecureInternet:
dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier)
if err != nil {
+ // We mark the organizations as expired because we got an error
+ // Note that in the docs it states that it only should happen when the Org ID doesn't exist
+ // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found
+ c.Discovery.MarkOrganizationsExpired()
return err
}
srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv)
@@ -442,7 +510,15 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
}
// callbacks
- return c.callbacks(ck, srv, false)
+ err = c.callbacks(ck, srv, false)
+ if err != nil {
+ return err
+ }
+ terr := c.forwardTokens(srv)
+ if terr != nil {
+ log.Logger.Debugf("failed to forward tokens after adding: %v", terr)
+ }
+ return nil
}
func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) {
@@ -529,6 +605,14 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
cfg, err = c.config(ck, srv, pTCP, true)
}
+ // tokens might be updated, forward them
+ defer func() {
+ terr := c.forwardTokens(srv)
+ if terr != nil {
+ log.Logger.Debugf("failed to forward tokens after get config: %v", terr)
+ }
+ }()
+
// still an error, return nil with the error
if err != nil {
return nil, err
@@ -700,15 +784,18 @@ func (c *Client) Cleanup(ck *cookie.Cookie) (err error) {
if err != nil {
return err
}
- // TODO: Support cookie context here
- // if server.NeedsRelogin(context.Background(), srv) {
- // // TODO: ask client for tokens
- // }
+ err = c.updateTokens(srv)
+ if err != nil {
+ log.Logger.Debugf("failed to update tokens for disconnect: %v", err)
+ }
err = server.Disconnect(ck.Context(), srv)
if err != nil {
return err
}
- // TODO: Set tokens with callback
+ err = c.forwardTokens(srv)
+ if err != nil {
+ log.Logger.Debugf("failed to forward tokens after disconnect: %v", err)
+ }
return nil
}
@@ -740,6 +827,13 @@ func (c *Client) RenewSession(ck *cookie.Cookie) (err error) {
c.FSM.GoTransition(StateLoadingServer)
c.FSM.GoTransition(StateChosenServer)
}
+ // update tokens in the end
+ defer func() {
+ terr := c.forwardTokens(srv)
+ if terr != nil {
+ log.Logger.Debugf("failed to forward tokens after renew: %v", terr)
+ }
+ }()
// TODO: Maybe this can be deleted because we force auth now
server.MarkTokensForRenew(srv)
// run the callbacks by forcing auth
diff --git a/exports/exports.go b/exports/exports.go
index f33fedd..6fc7f33 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -9,6 +9,9 @@ typedef long long int (*ReadRxBytes)();
typedef int (*StateCB)(int oldstate, int newstate, void* data);
+typedef const char* (*TokenGetter)(const char* server, char* out, size_t len);
+typedef void (*TokenSetter)(const char* server, const char* tokens);
+
static long long int get_read_rx_bytes(ReadRxBytes read)
{
return read();
@@ -17,10 +20,19 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat
{
return callback(oldstate, newstate, data);
}
+static void call_token_getter(TokenGetter getter, const char* server, char* out, size_t len)
+{
+ getter(server, out, len);
+}
+static void call_token_setter(TokenSetter setter, const char* server, const char* tokens)
+{
+ setter(server, tokens);
+}
*/
import "C"
import (
+ "bytes"
"context"
"encoding/json"
"runtime/cgo"
@@ -29,6 +41,7 @@ import (
"github.com/go-errors/errors"
"github.com/eduvpn/eduvpn-common/client"
+ "github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/types/cookie"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
)
@@ -382,6 +395,64 @@ func getCookie(c C.uintptr_t) (*cookie.Cookie, error) {
return v, nil
}
+//export SetTokenHandler
+func SetTokenHandler(getter C.TokenGetter, setter C.TokenSetter) *C.char {
+ state, stateErr := getVPNState()
+ if stateErr != nil {
+ return getCError(stateErr)
+ }
+ state.TokenSetter = func(c srvtypes.Current, t srvtypes.Tokens) {
+ cJSON, err := getReturnData(c)
+ if err != nil {
+ log.Logger.Warningf("failed to get current server for setting tokens in exports: %v", err)
+ return
+ }
+ tJSON, err := getReturnData(t)
+ if err != nil {
+ log.Logger.Warningf("failed to get tokens for setting tokens in exports: %v", err)
+ return
+ }
+ c1 := C.CString(cJSON)
+ c2 := C.CString(tJSON)
+ C.call_token_setter(setter, c1, c2)
+ FreeString(c1)
+ FreeString(c2)
+ }
+
+ state.TokenGetter = func(c srvtypes.Current) *srvtypes.Tokens {
+ cJSON, err := getReturnData(c)
+ if err != nil {
+ log.Logger.Warningf("failed to get current server for getting tokens in exports: %v", err)
+ return nil
+ }
+ c1 := C.CString(cJSON)
+ // create an output buffer with size 2048
+ // In my testing tokens seem to be ~1033 bytes marshalled as JSON
+ d := make([]byte, 2048)
+
+ C.call_token_getter(getter, c1, (*C.char)(unsafe.Pointer(&d[0])), C.size_t(len(d)))
+ FreeString(c1)
+
+ // get null pointer index as unmarshalling wants it without
+ null := bytes.IndexByte(d, 0)
+ if null < 0 {
+ log.Logger.Warningf("output buffer is not NULL terminated")
+ return nil
+ }
+
+ var gotT srvtypes.Tokens
+ err = json.Unmarshal(d[:null], &gotT)
+ if err != nil {
+ log.Logger.Warningf("failed to get json data for getting tokens in exports: %v", err)
+ return nil
+ }
+ return &gotT
+ }
+
+
+ return nil
+}
+
//export CookieNew
func CookieNew() C.uintptr_t {
c := cookie.NewWithContext(context.Background())
diff --git a/internal/server/server.go b/internal/server/server.go
index e7229c5..c34158a 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -108,9 +108,6 @@ type ConfigData struct {
// The type of configuration
Type string
-
- // The tokens
- Tokens oauth.Token
}
// Public gets the public data from the types package
@@ -120,7 +117,6 @@ func (c *ConfigData) Public(dg bool) srvtypes.Configuration {
VPNConfig: c.Config,
Protocol: protocol.New(c.Type),
DefaultGateway: dg,
- Tokens: c.Tokens.Public(),
}
}
@@ -154,13 +150,7 @@ func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPN
cfg = wireguard.ConfigAddKey(cfg, key)
}
- t := oauth.Token{}
- o := srv.OAuth()
- if o != nil {
- t = o.Token()
- }
-
- return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil
+ return &ConfigData{Config: cfg, Type: proto}, nil
}
func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) {
@@ -178,14 +168,7 @@ func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigD
b.StartTime = time.Now()
b.EndTime = exp
- t := oauth.Token{}
-
- o := srv.OAuth()
- if o != nil {
- t = o.Token()
- }
-
- return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil
+ return &ConfigData{Config: cfg, Type: "openvpn"}, nil
}
func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) {
diff --git a/types/server/server.go b/types/server/server.go
index 1578f29..aea496d 100644
--- a/types/server/server.go
+++ b/types/server/server.go
@@ -130,9 +130,6 @@ type Configuration struct {
Protocol protocol.Protocol `json:"protocol"`
// DefaultGateway is a boolean that indicates whether or not this configuration should be configured as a default gateway
DefaultGateway bool `json:"default_gateway"`
- // Tokens is the updated tokens that we get back from the VPN configuration
- // They should be used by the client to save them in e.g. the keyring
- Tokens Tokens `json:"tokens"`
}
// Current is the struct that defines the current server
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
index 961b569..1a172af 100644
--- a/wrappers/python/eduvpn_common/loader.py
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -4,7 +4,7 @@ from collections import defaultdict
from ctypes import CDLL, c_char_p, c_int, c_void_p, cdll
from eduvpn_common import __version__
-from eduvpn_common.types import BoolError, DataError, ReadRxBytes, VPNStateChange
+from eduvpn_common.types import BoolError, DataError, ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange
def load_lib() -> CDLL:
@@ -88,6 +88,7 @@ def initialize_functions(lib: CDLL) -> None:
c_int,
], c_void_p
lib.RenewSession.argtypes, lib.RenewSession.restype = [c_int], c_void_p
+ lib.SetTokenHandler.argtypes, lib.SetTokenHandler.restype = [TokenGetter, TokenSetter], c_void_p
lib.Cleanup.argtypes, lib.Cleanup.restype = [c_int], c_void_p
lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p], c_void_p
lib.CookieNew.argtypes, lib.CookieNew.restype = [], c_int
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index b10e641..5d08ba9 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,8 +1,9 @@
+import ctypes
from enum import IntEnum
from typing import Any, Callable, Iterator, Optional
from eduvpn_common.loader import initialize_functions, load_lib
-from eduvpn_common.types import ReadRxBytes, VPNStateChange, decode_res, encode_args
+from eduvpn_common.types import ReadRxBytes, TokenGetter, TokenSetter, VPNStateChange, decode_res, encode_args
class WrappedError(Exception):
@@ -56,6 +57,9 @@ class EduVPN(object):
self.version = version
self.config_directory = config_directory
self.jar = Jar(lambda x: self.go_function(self.lib.CookieCancel, x))
+ self.callback = None
+ self.token_setter = None
+ self.token_getter = None
# Load the library
self.lib = load_lib()
@@ -88,8 +92,8 @@ class EduVPN(object):
This removes the object from internal bookkeeping and saves the configuration
"""
self.go_function(self.lib.Deregister)
- global callback_object
- callback_object = None
+ global global_object
+ global_object = None
def register(self, handler: Optional[Callable] = None, debug: bool = False) -> None:
"""Register the Go shared library.
@@ -99,10 +103,11 @@ class EduVPN(object):
:param debug: bool: (Default value = False): Whether or not we want to enable debug logging
"""
- global callback_object
- if callback_object is not None:
+ global global_object
+ if global_object is not None:
raise Exception("Already registered")
- callback_object = handler
+ self.callback = handler
+ global_object = self
register_err = self.go_function(
self.lib.Register,
self.name,
@@ -244,6 +249,14 @@ class EduVPN(object):
if location_err:
forwardError(location_err)
+ def set_token_handler(self, getter: Callable, setter: Callable) -> None:
+ self.token_setter = setter
+ self.token_getter = getter
+ handler_err = self.go_function(self.lib.SetTokenHandler, token_getter, token_setter)
+
+ if handler_err:
+ forwardError(handler_err)
+
def cookie_reply(self, cookie: int, data: str) -> None:
"""Reply with the given cookie and data"""
cookie_err = self.go_function(self.lib.CookieReply, cookie, data)
@@ -289,8 +302,30 @@ class EduVPN(object):
self.jar.cancel()
-callback_object: Optional[Callable] = None
+global_object: Optional[EduVPN] = None
+@TokenSetter
+def token_setter(server: ctypes.c_char_p, tokens: ctypes.c_char_p):
+ global global_object
+ if global_object is None:
+ return
+ if global_object.token_setter is None:
+ return 0
+ global_object.token_setter(server.decode(), tokens.decode())
+
+@TokenGetter
+def token_getter(server: ctypes.c_char_p, buf: ctypes.c_char_p, size: ctypes.c_size_t):
+ global global_object
+ if global_object is None:
+ return
+ if global_object.token_getter is None:
+ return
+ got = global_object.token_getter(server.decode())
+ if got is None:
+ return
+
+ outbuf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char * size))
+ outbuf.contents.value = got.encode("utf-8")
@VPNStateChange
def state_callback(old_state: int, new_state: int, data: str) -> int:
@@ -302,9 +337,12 @@ def state_callback(old_state: int, new_state: int, data: str) -> int:
:meta private:
"""
- if callback_object is None:
+ global global_object
+ if global_object is None:
+ return 0
+ if global_object.callback is None:
return 0
- handled = callback_object(old_state, new_state, data.decode("utf-8"))
+ handled = global_object.callback(old_state, new_state, data.decode("utf-8"))
if handled:
return 1
return 0
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
index 1eba468..f83e710 100644
--- a/wrappers/python/eduvpn_common/types.py
+++ b/wrappers/python/eduvpn_common/types.py
@@ -1,10 +1,13 @@
from ctypes import (
CDLL,
CFUNCTYPE,
+ POINTER,
Structure,
+ c_char,
c_char_p,
c_int,
c_ulonglong,
+ c_size_t,
c_void_p,
cast,
)
@@ -32,8 +35,8 @@ class BoolError(Structure):
# The type for a Go state change callback
VPNStateChange = CFUNCTYPE(c_int, c_int, c_int, c_char_p)
ReadRxBytes = CFUNCTYPE(c_ulonglong)
-UpdateToken = CFUNCTYPE(None, c_char_p, c_void_p, c_void_p)
-
+TokenGetter = CFUNCTYPE(c_void_p, c_char_p, POINTER(c_char), c_size_t)
+TokenSetter = CFUNCTYPE(c_void_p, c_char_p, c_char_p)
def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]:
"""Encode the arguments ready to be used by the Go library