From 4a23134e1e5d70a9c8c5857790dbf27585ca3b1f Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Fri, 29 Aug 2025 14:05:20 +0200 Subject: Discovery: Add cache argument and embed unmarshal on startup --- internal/discovery/discovery.go | 96 ++++++++++++++++++------------------ internal/discovery/discovery_test.go | 18 +++---- internal/discovery/manager.go | 4 +- internal/server/secureinternet.go | 8 +-- 4 files changed, 63 insertions(+), 63 deletions(-) (limited to 'internal') diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 81163f8..80a69eb 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -293,44 +293,20 @@ func (discovery *Discovery) DetermineServersUpdate() bool { return !time.Now().Before(upd) } -func (discovery *Discovery) previousOrganizations() (*Organizations, error) { - // If the version field is not zero then we have a cached struct - // We also immediately return this copy if we have no embedded JSON - if discovery.OrganizationList.Version != 0 || !HasCache { - return &discovery.OrganizationList, nil +func (discovery *Discovery) cachedOrgs() *Organizations { + if discovery.OrganizationList.Version == 0 { + return nil } - - // We do not have a cached struct, this we need to get it using the embedded JSON - var eo Organizations - if err := json.Unmarshal(eOrganizations, &eo); err != nil { - return nil, fmt.Errorf("failed parsing discovery organizations from the embedded cache with error: %w", err) - } - discovery.OrganizationList = eo - return &eo, nil -} - -func (discovery *Discovery) previousServers() (*Servers, error) { - // If the version field is not zero then we have a cached struct - // We also immediately return this copy if we have no embedded JSON - if discovery.ServerList.Version != 0 || !HasCache { - return &discovery.ServerList, nil - } - - // We do not have a cached struct, this we need to get it using the embedded JSON - var es Servers - if err := json.Unmarshal(eServers, &es); err != nil { - return nil, fmt.Errorf("failed parsing discovery servers from the embedded cache with error: %w", err) - } - discovery.ServerList = es - return &es, nil + return &discovery.OrganizationList } // Organizations returns the discovery organizations // The second return value is a boolean that indicates whether a fresh list was updated internally // If there was an error, a cached copy is returned if available. -func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations, bool, error) { - if !discovery.DetermineOrganizationsUpdate() { - return &discovery.OrganizationList, false, nil +// cache is set to true if there should be no network call done +func (discovery *Discovery) Organizations(ctx context.Context, cache bool) (*Organizations, bool, error) { + if cache || !discovery.DetermineOrganizationsUpdate() { + return discovery.cachedOrgs(), false, nil } file := "organization_list.json" var jsonDecode Organizations @@ -347,11 +323,7 @@ func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations, } } // Return previous with an error - orgs, perr := discovery.previousOrganizations() - if perr != nil { - slog.Warn("failed to get previous discovery organizations", "error", perr) - } - return orgs, false, err + return discovery.cachedOrgs(), false, err } if len(jsonDecode.List) == 0 { slog.Warn("fresh organization list is empty") @@ -365,12 +337,20 @@ func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations, return &discovery.OrganizationList, true, nil } +func (discovery *Discovery) cachedServers() *Servers { + if discovery.ServerList.Version == 0 { + return nil + } + return &discovery.ServerList +} + // Servers returns the discovery servers // The second return value is a boolean that indicates whether a fresh list was updated internally // If there was an error, a cached copy is returned if available. -func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error) { - if !discovery.DetermineServersUpdate() { - return &discovery.ServerList, false, nil +// cache is set to true if there should be no network call done +func (discovery *Discovery) Servers(ctx context.Context, cache bool) (*Servers, bool, error) { + if cache || !discovery.DetermineServersUpdate() { + return discovery.cachedServers(), false, nil } file := "server_list.json" var jsonDecode Servers @@ -387,11 +367,7 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error) } } // Return previous with an error - srvs, perr := discovery.previousServers() - if perr != nil { - slog.Warn("failed to get previous discovery server", "error", perr) - } - return srvs, false, err + return discovery.cachedServers(), false, err } if len(jsonDecode.List) == 0 { slog.Warn("fresh server list is empty") @@ -407,9 +383,17 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error) // UpdateServers updates the discovery servers to the new version // It does this by checking versions -func (discovery *Discovery) UpdateServers(other Discovery) { - if other.ServerList.Version >= discovery.ServerList.Version { - discovery.ServerList = other.ServerList +func (discovery *Discovery) UpdateServers(other Servers) { + if other.Version >= discovery.ServerList.Version { + discovery.ServerList = other + } +} + +// UpdateOrganizations updates the discovery organizations to the new version +// It does this by checking versions +func (discovery *Discovery) UpdateOrganizations(other Organizations) { + if other.Version >= discovery.OrganizationList.Version { + discovery.OrganizationList = other } } @@ -429,3 +413,19 @@ func (discovery *Discovery) Copy() (Discovery, error) { return dest, nil } + +// Fill makes sure that the cache is filled with the embedded discovery +func (discovery *Discovery) Fill() error { + if !HasCache { + return nil + } + + var err error + var es Servers + err = errors.Join(err, json.Unmarshal(eServers, &es)) + discovery.UpdateServers(es) + var eo Organizations + err = errors.Join(err, json.Unmarshal(eOrganizations, &eo)) + discovery.UpdateOrganizations(eo) + return err +} diff --git a/internal/discovery/discovery_test.go b/internal/discovery/discovery_test.go index 08c5ef8..802123a 100644 --- a/internal/discovery/discovery_test.go +++ b/internal/discovery/discovery_test.go @@ -23,7 +23,7 @@ func TestServers(t *testing.T) { } d := &Discovery{httpClient: c} // get servers - _, fresh, err := d.Servers(context.Background()) + _, fresh, err := d.Servers(context.Background(), false) if !fresh { t.Fatalf("Did not obtain the server list fresh") } @@ -51,7 +51,7 @@ func TestServers(t *testing.T) { SupportContact: []string{"mailto:test@example.org"}, } // conditional requests: this should not be fetched fresh - _, fresh, err = d.Servers(context.Background()) + _, fresh, err = d.Servers(context.Background(), false) if fresh { t.Fatalf("Obtained the server list fresh with conditional requests") } @@ -60,7 +60,7 @@ func TestServers(t *testing.T) { } // mock conditional requests d.ServerList.UpdateHeader = time.Time{} - s1, fresh, err := d.Servers(context.Background()) + s1, fresh, err := d.Servers(context.Background(), false) if !fresh { t.Fatalf("Did not obtain the server list fresh with mocked conditional request and mocked entry") } @@ -74,7 +74,7 @@ func TestServers(t *testing.T) { // Shutdown the server s.Close() // Test if we get the same cached copy - s2, fresh, err := d.Servers(context.Background()) + s2, fresh, err := d.Servers(context.Background(), false) // We should not get an error as the timestamp is not expired if fresh { t.Fatalf("The server list was obtained fresh") @@ -90,7 +90,7 @@ func TestServers(t *testing.T) { // we should return the previous with an error d.ServerList.Timestamp = time.Now().Add(-1 * time.Hour) - s3, fresh, err := d.Servers(context.Background()) + s3, fresh, err := d.Servers(context.Background(), false) if fresh { t.Fatalf("Server list was gotten fresh") } @@ -115,7 +115,7 @@ func TestOrganizations(t *testing.T) { } d := &Discovery{httpClient: c} // get servers - _, fresh, err := d.Organizations(context.Background()) + _, fresh, err := d.Organizations(context.Background(), false) if !fresh { t.Fatalf("The organization list was not obtained fresh") } @@ -145,7 +145,7 @@ func TestOrganizations(t *testing.T) { }, } - _, fresh, err = d.Organizations(context.Background()) + _, fresh, err = d.Organizations(context.Background(), false) if fresh { t.Fatalf("Obtained the organization list fresh with conditional requests") } @@ -154,7 +154,7 @@ func TestOrganizations(t *testing.T) { } // mock conditional requests d.OrganizationList.UpdateHeader = time.Time{} - s1, fresh, err := d.Organizations(context.Background()) + s1, fresh, err := d.Organizations(context.Background(), false) if !fresh { t.Fatalf("Did not obtain the organization list fresh after inserting a mock entry, faking expiry and mocking conditional request") } @@ -169,7 +169,7 @@ func TestOrganizations(t *testing.T) { s.Close() // Test if we get the same cached copy // We should not get an error as the timestamp is not zero - s2, fresh, err := d.Organizations(context.Background()) + s2, fresh, err := d.Organizations(context.Background(), false) if fresh { t.Fatalf("The organization list is freshly obtained") } diff --git a/internal/discovery/manager.go b/internal/discovery/manager.go index 134525b..4fb4f8e 100644 --- a/internal/discovery/manager.go +++ b/internal/discovery/manager.go @@ -72,10 +72,10 @@ func (m *Manager) Startup(ctx context.Context, cb func()) { } m.unlock(false) // we already log the warning - discoCopy.Servers(ctx) //nolint:errcheck + discoCopy.Servers(ctx, false) //nolint:errcheck m.lock(true) - m.disco.UpdateServers(discoCopy) + m.disco.UpdateServers(discoCopy.ServerList) m.unlock(true) m.wait.Done() diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index e0d081a..69b1e97 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -67,8 +67,8 @@ func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgI defer release() // the only thing we can do is log warn // this is already done in the functions - newd.Servers(ctx) //nolint:errcheck - newd.Organizations(ctx) //nolint:errcheck + newd.Servers(ctx, false) //nolint:errcheck + newd.Organizations(ctx, false) //nolint:errcheck updorg, updsrv, err := newd.SecureHomeArgs(orgID) if err != nil { return "", err @@ -146,9 +146,9 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery // the only thing we can do is log warn // this is already done in the functions newd.MarkServersExpired() - newd.Servers(ctx) //nolint:errcheck + newd.Servers(ctx, false) //nolint:errcheck newd.MarkOrganizationsExpired() - newd.Organizations(ctx) //nolint:errcheck + newd.Organizations(ctx, false) //nolint:errcheck updorg, updsrv, err := newd.SecureHomeArgs(orgID) if err != nil { return "", err -- cgit v1.2.3