diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-07-10 14:39:34 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-07-17 14:00:03 +0000 |
| commit | a1879195a727d7b90347ed11f86d85fac6541df7 (patch) | |
| tree | ef19423671009552181f759b4a9162e7d91bf82a /client/client.go | |
| parent | 7f8af5845ddec1816f93a2cb013f0818c19caab3 (diff) | |
Client + Discovery: Fetch dscovery at startup using DiscoveryStartup
With a manager that locks and copies such that no race conditions happen
Diffstat (limited to 'client/client.go')
| -rw-r--r-- | client/client.go | 80 |
1 files changed, 54 insertions, 26 deletions
diff --git a/client/client.go b/client/client.go index 9b51d82..0fd4230 100644 --- a/client/client.go +++ b/client/client.go @@ -53,20 +53,12 @@ type Client struct { proxy Proxy mu sync.Mutex + + discoMan DiscoManager } -// MarkOrganizationsExpired marks the discovery organization list as expired -// it's a no-op if the type `t` is not secure internet -// or if discovery is nil -func (c *Client) MarkOrganizationsExpired(t srvtypes.Type) { - if t != srvtypes.TypeSecureInternet { - return - } - disco := c.cfg.Discovery() - if disco == nil { - return - } - disco.MarkOrganizationsExpired() +func (c *Client) DiscoveryStartup(cb func()) { + c.discoMan.Startup(context.Background(), cb) } // GettingConfig is defined here to satisfy the server.Callbacks interface @@ -167,12 +159,19 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // set the servers c.Servers = server.NewServers(c.Name, c, c.cfg.V2) - // the first fetch for the servers should be fresh - c.cfg.Discovery().MarkServersExpired() + c.discoMan = DiscoManager{disco: c.cfg.Discovery()} + + if !c.hasDiscovery() { + return c, nil + } + disco, release := c.discoMan.Discovery(true) + defer release() + disco.MarkServersExpired() if !c.cfg.HasSecureInternet() { - c.cfg.Discovery().MarkOrganizationsExpired() + disco.MarkOrganizationsExpired() } + return c, nil } @@ -255,6 +254,9 @@ func (c *Client) Register() error { // Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct. func (c *Client) Deregister() { + c.discoMan.Cancel() + + _, release := c.discoMan.Discovery(false) // save the config c.TrySave() @@ -266,6 +268,7 @@ func (c *Client) Deregister() { // Close the log file _ = log.Logger.Close() + release() // Empty out the state *c = Client{} @@ -290,8 +293,8 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error { - locs := c.cfg.Discovery().SecureLocationList() +func (c *Client) locationCallback(ck *cookie.Cookie, disco *discovery.Discovery, orgID string) error { + locs := disco.SecureLocationList() errChan := make(chan error) go func() { err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ @@ -342,6 +345,8 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. // If we have failed to add the server, we remove it again // We add the server because we can then obtain it in other callback functions previousState := c.FSM.Current + + var release func() defer func() { // If we must run callbacks, go to the previous state if we're not in it if !ni && !c.FSM.InState(previousState) { @@ -350,6 +355,9 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. if err == nil { c.TrySave() } + if release != nil { + release() + } }() if !ni { @@ -367,14 +375,16 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } } + var disco *discovery.Discovery + disco, release = c.discoMan.Discovery(true) switch _type { case srvtypes.TypeInstituteAccess: - err = c.Servers.AddInstitute(ck.Context(), c.cfg.Discovery(), identifier, ot) + err = c.Servers.AddInstitute(ck.Context(), disco, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier) } case srvtypes.TypeSecureInternet: - err = c.Servers.AddSecure(ck.Context(), c.cfg.Discovery(), identifier, ot) + err = c.Servers.AddSecure(ck.Context(), disco, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier) } @@ -412,6 +422,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type) } + var release func() defer func() { if err == nil { // it could be that we are not in getting config yet if we have just done authorization @@ -422,6 +433,9 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. c.FSM.GoTransition(previousState) //nolint:errcheck } c.TrySave() + if release != nil { + release() + } }() identifier, err = c.convertIdentifier(identifier, _type) @@ -439,7 +453,8 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. } ctx := ck.Context() - disco := c.cfg.Discovery() + var disco *discovery.Discovery + disco, release = c.discoMan.Discovery(true) if _type != srvtypes.TypeCustom { // make sure the servers are fetched fresh _, _, dserverr := disco.Servers(ctx) @@ -462,7 +477,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. var cErr *discovery.ErrCountryNotFound if errors.As(err, &cErr) { - err = c.locationCallback(ck, identifier) + err = c.locationCallback(ck, disco, identifier) if err == nil { srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup) } @@ -499,14 +514,21 @@ func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error if err != nil { return i18nerr.WrapInternalf(err, "Failed to remove server: '%s'", identifier) } - c.MarkOrganizationsExpired(_type) + disco, release := c.discoMan.Discovery(true) + defer release() + if _type == srvtypes.TypeSecureInternet { + disco.MarkOrganizationsExpired() + } c.TrySave() return nil } // CurrentServer gets the current server that is configured func (c *Client) CurrentServer() (*srvtypes.Current, error) { - curr, err := c.Servers.PublicCurrent(c.cfg.Discovery()) + // TODO: do clients call this during a write mutex? + disco, release := c.discoMan.Discovery(false) + defer release() + curr, err := c.Servers.PublicCurrent(disco) if err != nil { return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved") } @@ -561,7 +583,9 @@ func (c *Client) Cleanup(ck *cookie.Cookie) error { if err != nil { return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection") } - auth, err := srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), tok, true) + disco, release := c.discoMan.Discovery(true) + defer release() + auth, err := srv.ServerWithCallbacks(ck.Context(), disco, tok, true) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection") } @@ -608,7 +632,9 @@ func (c *Client) RenewSession(ck *cookie.Cookie) error { } // getting a server with no tokens means re-authorize - _, err = srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), nil, false) + disco, release := c.discoMan.Discovery(true) + defer release() + _, err = srv.ServerWithCallbacks(ck.Context(), disco, nil, false) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session") } @@ -629,6 +655,8 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR // ServerList gets the list of servers func (c *Client) ServerList() (*srvtypes.List, error) { - g := c.cfg.V2.PublicList(c.cfg.Discovery()) + disco, release := c.discoMan.Discovery(false) + defer release() + g := c.cfg.V2.PublicList(disco) return g, nil } |
