summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2024-07-10 14:39:34 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2024-07-17 14:00:03 +0000
commita1879195a727d7b90347ed11f86d85fac6541df7 (patch)
treeef19423671009552181f759b4a9162e7d91bf82a
parent7f8af5845ddec1816f93a2cb013f0818c19caab3 (diff)
Client + Discovery: Fetch dscovery at startup using DiscoveryStartup
With a manager that locks and copies such that no race conditions happen
-rw-r--r--client/client.go80
-rw-r--r--client/discovery.go80
-rw-r--r--exports/exports.go27
-rw-r--r--internal/api/api_test.go2
-rw-r--r--internal/discovery/discovery.go21
-rw-r--r--internal/server/secureinternet.go12
-rw-r--r--wrappers/python/eduvpn_common/loader.py2
-rw-r--r--wrappers/python/eduvpn_common/main.py13
-rw-r--r--wrappers/python/eduvpn_common/types.py1
9 files changed, 203 insertions, 35 deletions
diff --git a/client/client.go b/client/client.go
index 9b51d82..0fd4230 100644
--- a/client/client.go
+++ b/client/client.go
@@ -53,20 +53,12 @@ type Client struct {
proxy Proxy
mu sync.Mutex
+
+ discoMan DiscoManager
}
-// MarkOrganizationsExpired marks the discovery organization list as expired
-// it's a no-op if the type `t` is not secure internet
-// or if discovery is nil
-func (c *Client) MarkOrganizationsExpired(t srvtypes.Type) {
- if t != srvtypes.TypeSecureInternet {
- return
- }
- disco := c.cfg.Discovery()
- if disco == nil {
- return
- }
- disco.MarkOrganizationsExpired()
+func (c *Client) DiscoveryStartup(cb func()) {
+ c.discoMan.Startup(context.Background(), cb)
}
// GettingConfig is defined here to satisfy the server.Callbacks interface
@@ -167,12 +159,19 @@ func New(name string, version string, directory string, stateCallback func(FSMSt
// set the servers
c.Servers = server.NewServers(c.Name, c, c.cfg.V2)
- // the first fetch for the servers should be fresh
- c.cfg.Discovery().MarkServersExpired()
+ c.discoMan = DiscoManager{disco: c.cfg.Discovery()}
+
+ if !c.hasDiscovery() {
+ return c, nil
+ }
+ disco, release := c.discoMan.Discovery(true)
+ defer release()
+ disco.MarkServersExpired()
if !c.cfg.HasSecureInternet() {
- c.cfg.Discovery().MarkOrganizationsExpired()
+ disco.MarkOrganizationsExpired()
}
+
return c, nil
}
@@ -255,6 +254,9 @@ func (c *Client) Register() error {
// Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct.
func (c *Client) Deregister() {
+ c.discoMan.Cancel()
+
+ _, release := c.discoMan.Discovery(false)
// save the config
c.TrySave()
@@ -266,6 +268,7 @@ func (c *Client) Deregister() {
// Close the log file
_ = log.Logger.Close()
+ release()
// Empty out the state
*c = Client{}
@@ -290,8 +293,8 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) {
}, nil
}
-func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error {
- locs := c.cfg.Discovery().SecureLocationList()
+func (c *Client) locationCallback(ck *cookie.Cookie, disco *discovery.Discovery, orgID string) error {
+ locs := disco.SecureLocationList()
errChan := make(chan error)
go func() {
err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{
@@ -342,6 +345,8 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
// If we have failed to add the server, we remove it again
// We add the server because we can then obtain it in other callback functions
previousState := c.FSM.Current
+
+ var release func()
defer func() {
// If we must run callbacks, go to the previous state if we're not in it
if !ni && !c.FSM.InState(previousState) {
@@ -350,6 +355,9 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
if err == nil {
c.TrySave()
}
+ if release != nil {
+ release()
+ }
}()
if !ni {
@@ -367,14 +375,16 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
}
}
+ var disco *discovery.Discovery
+ disco, release = c.discoMan.Discovery(true)
switch _type {
case srvtypes.TypeInstituteAccess:
- err = c.Servers.AddInstitute(ck.Context(), c.cfg.Discovery(), identifier, ot)
+ err = c.Servers.AddInstitute(ck.Context(), disco, identifier, ot)
if err != nil {
return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier)
}
case srvtypes.TypeSecureInternet:
- err = c.Servers.AddSecure(ck.Context(), c.cfg.Discovery(), identifier, ot)
+ err = c.Servers.AddSecure(ck.Context(), disco, identifier, ot)
if err != nil {
return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier)
}
@@ -412,6 +422,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type)
}
+ var release func()
defer func() {
if err == nil {
// it could be that we are not in getting config yet if we have just done authorization
@@ -422,6 +433,9 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
c.FSM.GoTransition(previousState) //nolint:errcheck
}
c.TrySave()
+ if release != nil {
+ release()
+ }
}()
identifier, err = c.convertIdentifier(identifier, _type)
@@ -439,7 +453,8 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
}
ctx := ck.Context()
- disco := c.cfg.Discovery()
+ var disco *discovery.Discovery
+ disco, release = c.discoMan.Discovery(true)
if _type != srvtypes.TypeCustom {
// make sure the servers are fetched fresh
_, _, dserverr := disco.Servers(ctx)
@@ -462,7 +477,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
var cErr *discovery.ErrCountryNotFound
if errors.As(err, &cErr) {
- err = c.locationCallback(ck, identifier)
+ err = c.locationCallback(ck, disco, identifier)
if err == nil {
srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup)
}
@@ -499,14 +514,21 @@ func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error
if err != nil {
return i18nerr.WrapInternalf(err, "Failed to remove server: '%s'", identifier)
}
- c.MarkOrganizationsExpired(_type)
+ disco, release := c.discoMan.Discovery(true)
+ defer release()
+ if _type == srvtypes.TypeSecureInternet {
+ disco.MarkOrganizationsExpired()
+ }
c.TrySave()
return nil
}
// CurrentServer gets the current server that is configured
func (c *Client) CurrentServer() (*srvtypes.Current, error) {
- curr, err := c.Servers.PublicCurrent(c.cfg.Discovery())
+ // TODO: do clients call this during a write mutex?
+ disco, release := c.discoMan.Discovery(false)
+ defer release()
+ curr, err := c.Servers.PublicCurrent(disco)
if err != nil {
return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved")
}
@@ -561,7 +583,9 @@ func (c *Client) Cleanup(ck *cookie.Cookie) error {
if err != nil {
return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection")
}
- auth, err := srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), tok, true)
+ disco, release := c.discoMan.Discovery(true)
+ defer release()
+ auth, err := srv.ServerWithCallbacks(ck.Context(), disco, tok, true)
if err != nil {
return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection")
}
@@ -608,7 +632,9 @@ func (c *Client) RenewSession(ck *cookie.Cookie) error {
}
// getting a server with no tokens means re-authorize
- _, err = srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), nil, false)
+ disco, release := c.discoMan.Discovery(true)
+ defer release()
+ _, err = srv.ServerWithCallbacks(ck.Context(), disco, nil, false)
if err != nil {
return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session")
}
@@ -629,6 +655,8 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR
// ServerList gets the list of servers
func (c *Client) ServerList() (*srvtypes.List, error) {
- g := c.cfg.V2.PublicList(c.cfg.Discovery())
+ disco, release := c.discoMan.Discovery(false)
+ defer release()
+ g := c.cfg.V2.PublicList(disco)
return g, nil
}
diff --git a/client/discovery.go b/client/discovery.go
index 415be9b..2758678 100644
--- a/client/discovery.go
+++ b/client/discovery.go
@@ -1,9 +1,13 @@
package client
import (
+ "context"
"sort"
"strings"
+ "sync"
+ "github.com/eduvpn/eduvpn-common/internal/discovery"
+ "github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/i18nerr"
"github.com/eduvpn/eduvpn-common/types/cookie"
discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
@@ -24,7 +28,10 @@ func (c *Client) DiscoOrganizations(ck *cookie.Cookie, search string) (*discotyp
return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported")
}
- orgs, fresh, err := c.cfg.Discovery().Organizations(ck.Context())
+ disco, release := c.discoMan.Discovery(true)
+ defer release()
+
+ orgs, fresh, err := disco.Organizations(ck.Context())
if fresh {
defer c.TrySave()
}
@@ -70,7 +77,9 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser
return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported")
}
- servs, fresh, err := c.cfg.Discovery().Servers(ck.Context())
+ disco, release := c.discoMan.Discovery(true)
+ defer release()
+ servs, fresh, err := disco.Servers(ck.Context())
if fresh {
defer c.TrySave()
}
@@ -105,3 +114,70 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser
List: retServs,
}, err
}
+
+type DiscoManager struct {
+ disco *discovery.Discovery
+
+ cancel context.CancelFunc
+ mu sync.RWMutex
+ wait sync.WaitGroup
+}
+
+func (m *DiscoManager) lock(write bool) {
+ if write {
+ m.mu.Lock()
+ return
+ }
+ m.mu.RLock()
+}
+
+func (m *DiscoManager) unlock(write bool) {
+ if write {
+ m.mu.Unlock()
+ return
+ }
+ m.mu.RUnlock()
+}
+
+func (m *DiscoManager) Discovery(write bool) (*discovery.Discovery, func()) {
+ if write {
+ m.wait.Wait()
+ }
+ m.lock(write)
+ return m.disco, func() {
+ m.unlock(write)
+ }
+}
+
+func (m *DiscoManager) Cancel() {
+ if m.cancel != nil {
+ m.cancel()
+ }
+ m.wait.Wait()
+}
+
+func (m *DiscoManager) Startup(ctx context.Context, cb func()) {
+ ctx, cancel := context.WithCancel(ctx)
+ m.cancel = cancel
+ m.wait.Add(1)
+ go func() {
+ defer m.wait.Done()
+ m.lock(false)
+ discoCopy, err := m.disco.Copy()
+ if err != nil {
+ log.Logger.Warningf("internal error, failed to clone discovery, %v", err)
+ return
+ }
+ m.unlock(false)
+ // we already log the warning
+ discoCopy.Servers(ctx) //nolint:errcheck
+
+ m.lock(true)
+ m.disco.UpdateServers(discoCopy)
+ m.unlock(true)
+
+ if cb != nil {
+ cb()
+ }
+ }()
+}
diff --git a/exports/exports.go b/exports/exports.go
index 1e496f8..36372b4 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -21,6 +21,7 @@ typedef long long int (*ReadRxBytes)();
typedef int (*StateCB)(int oldstate, int newstate, void* data);
+typedef void (*RefreshList)();
typedef void (*TokenGetter)(const char* server_id, int server_type, char* out, size_t len);
typedef void (*TokenSetter)(const char* server_id, int server_type, const char* tokens);
typedef void (*ProxySetup)(int fd, const char* peer_ips);
@@ -34,6 +35,10 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat
{
return callback(oldstate, newstate, data);
}
+static void call_refresh_list(RefreshList refresh)
+{
+ refresh();
+}
static void call_token_getter(TokenGetter getter, const char* server_id, int server_type, char* out, size_t len)
{
getter(server_id, server_type, out, len);
@@ -1024,6 +1029,28 @@ func getCookie(c C.uintptr_t) (*cookie.Cookie, error) {
return v, nil
}
+
+// DiscoveryStartup does a discovery request in the background
+//
+// The `refresh` argument is a callback that is called when the refreshing is done
+// When this callback is thus called, the app SHOULD refresh the server list of the already configured servers
+// This DiscoveryStartup function MUST be called after calling `Register`
+//
+//export DiscoveryStartup
+func DiscoveryStartup(refresh C.RefreshList) *C.char {
+ state, stateErr := getVPNState()
+ if stateErr != nil {
+ return getCError(stateErr)
+ }
+ state.DiscoveryStartup(func() {
+ if refresh == nil {
+ return
+ }
+ C.call_refresh_list(refresh)
+ })
+ return nil
+}
+
// SetTokenHandler sets the token getters and token setters for OAuth
//
// Because the data that is saved does not contain OAuth tokens for server, the common lib asks and sets the tokens using these callback functions.
diff --git a/internal/api/api_test.go b/internal/api/api_test.go
index 397dd3c..fcf02e9 100644
--- a/internal/api/api_test.go
+++ b/internal/api/api_test.go
@@ -196,7 +196,7 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha
Type: server.TypeCustom,
BaseWK: serv.URL,
BaseAuthWK: serv.URL,
- ProcessAuth: func(ctx context.Context, in string) (string, error) {
+ ProcessAuth: func(_ context.Context, in string) (string, error) {
return in, nil
},
DisableAuthorize: false,
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 30ca801..231c12a 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -405,3 +405,24 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error)
}
return &discovery.ServerList, true, nil
}
+
+func (discovery *Discovery) UpdateServers(other Discovery) {
+ if other.ServerList.Version >= discovery.ServerList.Version {
+ discovery.ServerList = other.ServerList
+ }
+}
+
+func (discovery *Discovery) Copy() (Discovery, error) {
+ var dest Discovery
+ b, err := json.Marshal(discovery)
+ if err != nil {
+ return dest, err
+ }
+
+ err = json.Unmarshal(b, &dest)
+ if err != nil {
+ return dest, err
+ }
+
+ return dest, nil
+}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index f167756..990ceb3 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -34,8 +34,10 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org
BaseWK: dsrv.BaseURL,
BaseAuthWK: dsrv.BaseURL,
ProcessAuth: func(ctx context.Context, url string) (string, error) {
- disco.Servers(ctx)
- disco.Organizations(ctx)
+ // the only thing we can do is log warn
+ // this is already done in the functions
+ disco.Servers(ctx) //nolint:errcheck
+ disco.Organizations(ctx) //nolint:errcheck
updorg, updsrv, err := disco.SecureHomeArgs(orgID)
if err != nil {
return "", err
@@ -104,10 +106,12 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.
BaseWK: dloc.BaseURL,
BaseAuthWK: dhome.BaseURL,
ProcessAuth: func(ctx context.Context, url string) (string, error) {
+ // the only thing we can do is log warn
+ // this is already done in the functions
disco.MarkServersExpired()
- disco.Servers(ctx)
+ disco.Servers(ctx) //nolint:errcheck
disco.MarkOrganizationsExpired()
- disco.Organizations(ctx)
+ disco.Organizations(ctx) //nolint:errcheck
updorg, updsrv, err := disco.SecureHomeArgs(orgID)
if err != nil {
return "", err
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
index b74741f..4bfc55f 100644
--- a/wrappers/python/eduvpn_common/loader.py
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -8,6 +8,7 @@ from eduvpn_common.types import (
ProxyReady,
ProxySetup,
ReadRxBytes,
+ RefreshList,
TokenGetter,
TokenSetter,
VPNStateChange,
@@ -90,6 +91,7 @@ def initialize_functions(lib: CDLL) -> None:
c_void_p,
)
lib.RenewSession.argtypes, lib.RenewSession.restype = [c_int], c_void_p
+ lib.DiscoveryStartup.argtypes, lib.DiscoveryStartup.restype = [RefreshList], c_void_p
lib.SetTokenHandler.argtypes, lib.SetTokenHandler.restype = (
[
TokenGetter,
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index cb81e53..8c556e9 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -10,6 +10,7 @@ from eduvpn_common.types import (
ProxyReady,
ProxySetup,
ReadRxBytes,
+ RefreshList,
TokenGetter,
TokenSetter,
VPNStateChange,
@@ -180,12 +181,14 @@ class EduVPN(object):
def get_disco_organizations(self, search="") -> str:
orgs, _ = self.go_cookie_function(self.lib.DiscoOrganizations, search)
- # TODO: Log error
+ # We don't log anything here as we want to return a result and we don't want to throw here
+ # we already log for errors in common
return orgs
def get_disco_servers(self, search="") -> str:
servers, _ = self.go_cookie_function(self.lib.DiscoServers, search)
- # TODO: Log error
+ # We don't log anything here as we want to return a result and we don't want to throw here
+ # we already log for errors in common
return servers
def get_servers(self) -> str:
@@ -298,6 +301,12 @@ class EduVPN(object):
if location_err:
forwardError(location_err)
+ def discovery_startup(self, refresh: RefreshList) -> None:
+ refresh_err = self.go_function(self.lib.DiscoveryStartup, refresh)
+
+ if refresh_err:
+ forwardError(refresh_err)
+
def set_token_handler(self, getter: Callable, setter: Callable) -> None:
self.token_setter = setter
self.token_getter = getter
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
index 21690fb..716556e 100644
--- a/wrappers/python/eduvpn_common/types.py
+++ b/wrappers/python/eduvpn_common/types.py
@@ -39,6 +39,7 @@ VPNStateChange = CFUNCTYPE(c_int, c_int, c_int, c_char_p)
ProxySetup = CFUNCTYPE(c_void_p, c_int, c_char_p)
ProxyReady = CFUNCTYPE(c_void_p)
ReadRxBytes = CFUNCTYPE(c_ulonglong)
+RefreshList = CFUNCTYPE(c_void_p)
TokenGetter = CFUNCTYPE(c_void_p, c_char_p, c_int, POINTER(c_char), c_size_t)
TokenSetter = CFUNCTYPE(c_void_p, c_char_p, c_int, c_char_p)