From a1879195a727d7b90347ed11f86d85fac6541df7 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Wed, 10 Jul 2024 14:39:34 +0200 Subject: Client + Discovery: Fetch dscovery at startup using DiscoveryStartup With a manager that locks and copies such that no race conditions happen --- client/client.go | 80 ++++++++++++++++++++++++++++++++++++----------------- client/discovery.go | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 132 insertions(+), 28 deletions(-) (limited to 'client') diff --git a/client/client.go b/client/client.go index 9b51d82..0fd4230 100644 --- a/client/client.go +++ b/client/client.go @@ -53,20 +53,12 @@ type Client struct { proxy Proxy mu sync.Mutex + + discoMan DiscoManager } -// MarkOrganizationsExpired marks the discovery organization list as expired -// it's a no-op if the type `t` is not secure internet -// or if discovery is nil -func (c *Client) MarkOrganizationsExpired(t srvtypes.Type) { - if t != srvtypes.TypeSecureInternet { - return - } - disco := c.cfg.Discovery() - if disco == nil { - return - } - disco.MarkOrganizationsExpired() +func (c *Client) DiscoveryStartup(cb func()) { + c.discoMan.Startup(context.Background(), cb) } // GettingConfig is defined here to satisfy the server.Callbacks interface @@ -167,12 +159,19 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // set the servers c.Servers = server.NewServers(c.Name, c, c.cfg.V2) - // the first fetch for the servers should be fresh - c.cfg.Discovery().MarkServersExpired() + c.discoMan = DiscoManager{disco: c.cfg.Discovery()} + + if !c.hasDiscovery() { + return c, nil + } + disco, release := c.discoMan.Discovery(true) + defer release() + disco.MarkServersExpired() if !c.cfg.HasSecureInternet() { - c.cfg.Discovery().MarkOrganizationsExpired() + disco.MarkOrganizationsExpired() } + return c, nil } @@ -255,6 +254,9 @@ func (c *Client) Register() error { // Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct. func (c *Client) Deregister() { + c.discoMan.Cancel() + + _, release := c.discoMan.Discovery(false) // save the config c.TrySave() @@ -266,6 +268,7 @@ func (c *Client) Deregister() { // Close the log file _ = log.Logger.Close() + release() // Empty out the state *c = Client{} @@ -290,8 +293,8 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error { - locs := c.cfg.Discovery().SecureLocationList() +func (c *Client) locationCallback(ck *cookie.Cookie, disco *discovery.Discovery, orgID string) error { + locs := disco.SecureLocationList() errChan := make(chan error) go func() { err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ @@ -342,6 +345,8 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. // If we have failed to add the server, we remove it again // We add the server because we can then obtain it in other callback functions previousState := c.FSM.Current + + var release func() defer func() { // If we must run callbacks, go to the previous state if we're not in it if !ni && !c.FSM.InState(previousState) { @@ -350,6 +355,9 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. if err == nil { c.TrySave() } + if release != nil { + release() + } }() if !ni { @@ -367,14 +375,16 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } } + var disco *discovery.Discovery + disco, release = c.discoMan.Discovery(true) switch _type { case srvtypes.TypeInstituteAccess: - err = c.Servers.AddInstitute(ck.Context(), c.cfg.Discovery(), identifier, ot) + err = c.Servers.AddInstitute(ck.Context(), disco, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier) } case srvtypes.TypeSecureInternet: - err = c.Servers.AddSecure(ck.Context(), c.cfg.Discovery(), identifier, ot) + err = c.Servers.AddSecure(ck.Context(), disco, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier) } @@ -412,6 +422,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type) } + var release func() defer func() { if err == nil { // it could be that we are not in getting config yet if we have just done authorization @@ -422,6 +433,9 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. c.FSM.GoTransition(previousState) //nolint:errcheck } c.TrySave() + if release != nil { + release() + } }() identifier, err = c.convertIdentifier(identifier, _type) @@ -439,7 +453,8 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. } ctx := ck.Context() - disco := c.cfg.Discovery() + var disco *discovery.Discovery + disco, release = c.discoMan.Discovery(true) if _type != srvtypes.TypeCustom { // make sure the servers are fetched fresh _, _, dserverr := disco.Servers(ctx) @@ -462,7 +477,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. var cErr *discovery.ErrCountryNotFound if errors.As(err, &cErr) { - err = c.locationCallback(ck, identifier) + err = c.locationCallback(ck, disco, identifier) if err == nil { srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup) } @@ -499,14 +514,21 @@ func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error if err != nil { return i18nerr.WrapInternalf(err, "Failed to remove server: '%s'", identifier) } - c.MarkOrganizationsExpired(_type) + disco, release := c.discoMan.Discovery(true) + defer release() + if _type == srvtypes.TypeSecureInternet { + disco.MarkOrganizationsExpired() + } c.TrySave() return nil } // CurrentServer gets the current server that is configured func (c *Client) CurrentServer() (*srvtypes.Current, error) { - curr, err := c.Servers.PublicCurrent(c.cfg.Discovery()) + // TODO: do clients call this during a write mutex? + disco, release := c.discoMan.Discovery(false) + defer release() + curr, err := c.Servers.PublicCurrent(disco) if err != nil { return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved") } @@ -561,7 +583,9 @@ func (c *Client) Cleanup(ck *cookie.Cookie) error { if err != nil { return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection") } - auth, err := srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), tok, true) + disco, release := c.discoMan.Discovery(true) + defer release() + auth, err := srv.ServerWithCallbacks(ck.Context(), disco, tok, true) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection") } @@ -608,7 +632,9 @@ func (c *Client) RenewSession(ck *cookie.Cookie) error { } // getting a server with no tokens means re-authorize - _, err = srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), nil, false) + disco, release := c.discoMan.Discovery(true) + defer release() + _, err = srv.ServerWithCallbacks(ck.Context(), disco, nil, false) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session") } @@ -629,6 +655,8 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR // ServerList gets the list of servers func (c *Client) ServerList() (*srvtypes.List, error) { - g := c.cfg.V2.PublicList(c.cfg.Discovery()) + disco, release := c.discoMan.Discovery(false) + defer release() + g := c.cfg.V2.PublicList(disco) return g, nil } diff --git a/client/discovery.go b/client/discovery.go index 415be9b..2758678 100644 --- a/client/discovery.go +++ b/client/discovery.go @@ -1,9 +1,13 @@ package client import ( + "context" "sort" "strings" + "sync" + "github.com/eduvpn/eduvpn-common/internal/discovery" + "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/i18nerr" "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" @@ -24,7 +28,10 @@ func (c *Client) DiscoOrganizations(ck *cookie.Cookie, search string) (*discotyp return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported") } - orgs, fresh, err := c.cfg.Discovery().Organizations(ck.Context()) + disco, release := c.discoMan.Discovery(true) + defer release() + + orgs, fresh, err := disco.Organizations(ck.Context()) if fresh { defer c.TrySave() } @@ -70,7 +77,9 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported") } - servs, fresh, err := c.cfg.Discovery().Servers(ck.Context()) + disco, release := c.discoMan.Discovery(true) + defer release() + servs, fresh, err := disco.Servers(ck.Context()) if fresh { defer c.TrySave() } @@ -105,3 +114,70 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser List: retServs, }, err } + +type DiscoManager struct { + disco *discovery.Discovery + + cancel context.CancelFunc + mu sync.RWMutex + wait sync.WaitGroup +} + +func (m *DiscoManager) lock(write bool) { + if write { + m.mu.Lock() + return + } + m.mu.RLock() +} + +func (m *DiscoManager) unlock(write bool) { + if write { + m.mu.Unlock() + return + } + m.mu.RUnlock() +} + +func (m *DiscoManager) Discovery(write bool) (*discovery.Discovery, func()) { + if write { + m.wait.Wait() + } + m.lock(write) + return m.disco, func() { + m.unlock(write) + } +} + +func (m *DiscoManager) Cancel() { + if m.cancel != nil { + m.cancel() + } + m.wait.Wait() +} + +func (m *DiscoManager) Startup(ctx context.Context, cb func()) { + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wait.Add(1) + go func() { + defer m.wait.Done() + m.lock(false) + discoCopy, err := m.disco.Copy() + if err != nil { + log.Logger.Warningf("internal error, failed to clone discovery, %v", err) + return + } + m.unlock(false) + // we already log the warning + discoCopy.Servers(ctx) //nolint:errcheck + + m.lock(true) + m.disco.UpdateServers(discoCopy) + m.unlock(true) + + if cb != nil { + cb() + } + }() +} -- cgit v1.2.3