diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-04-12 22:46:54 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | 1e54063813efb6e822df36ad39d7f889a7f2e38b (patch) | |
| tree | 296878338a86584434f2e4ed3317a874f0ee37ab /internal/discovery | |
| parent | 3a62e2c16cf0f663595a9b0cb658ef2a060c9511 (diff) | |
Discovery: Pass a context around
Diffstat (limited to 'internal/discovery')
| -rw-r--r-- | internal/discovery/discovery.go | 15 | ||||
| -rw-r--r-- | internal/discovery/discovery_test.go | 11 |
2 files changed, 14 insertions, 12 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index ae7a307..06548f9 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -2,6 +2,7 @@ package discovery import ( + "context" "encoding/json" "fmt" "time" @@ -31,7 +32,7 @@ var DiscoURL = "https://disco.eduvpn.org/v2/" // file is a helper function that gets a disco JSON and fills the structure with it // If it was unsuccessful it returns an error. -func (discovery *Discovery) file(jsonFile string, previousVersion uint64, structure interface{}) error { +func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, structure interface{}) error { // No HTTP client present, create one if discovery.httpClient == nil { discovery.httpClient = http.NewClient() @@ -42,7 +43,7 @@ func (discovery *Discovery) file(jsonFile string, previousVersion uint64, struct if err != nil { return err } - _, body, err := discovery.httpClient.Get(jsonURL) + _, body, err := discovery.httpClient.Get(ctx, jsonURL) if err != nil { return err } @@ -53,7 +54,7 @@ func (discovery *Discovery) file(jsonFile string, previousVersion uint64, struct if err != nil { return err } - _, sigBody, err := discovery.httpClient.Get(sigURL) + _, sigBody, err := discovery.httpClient.Get(ctx, sigURL) if err != nil { return err } @@ -212,12 +213,12 @@ func (discovery *Discovery) previousServers() (*discotypes.Servers, error) { // Organizations returns the discovery organizations // If there was an error, a cached copy is returned if available. -func (discovery *Discovery) Organizations() (*discotypes.Organizations, error) { +func (discovery *Discovery) Organizations(ctx context.Context) (*discotypes.Organizations, error) { if !discovery.DetermineOrganizationsUpdate() { return &discovery.OrganizationList, nil } file := "organization_list.json" - err := discovery.file(file, discovery.OrganizationList.Version, &discovery.OrganizationList) + err := discovery.file(ctx, file, discovery.OrganizationList.Version, &discovery.OrganizationList) if err != nil { // Return previous with an error // TODO: Log here if we fail to get previous @@ -230,12 +231,12 @@ func (discovery *Discovery) Organizations() (*discotypes.Organizations, error) { // Servers returns the discovery servers // If there was an error, a cached copy is returned if available. -func (discovery *Discovery) Servers() (*discotypes.Servers, error) { +func (discovery *Discovery) Servers(ctx context.Context) (*discotypes.Servers, error) { if !discovery.DetermineServersUpdate() { return &discovery.ServerList, nil } file := "server_list.json" - err := discovery.file(file, discovery.ServerList.Version, &discovery.ServerList) + err := discovery.file(ctx, file, discovery.ServerList.Version, &discovery.ServerList) if err != nil { // Return previous with an error // TODO: Log here if we fail to get previous diff --git a/internal/discovery/discovery_test.go b/internal/discovery/discovery_test.go index 93ab51e..317aa50 100644 --- a/internal/discovery/discovery_test.go +++ b/internal/discovery/discovery_test.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "net/http" "reflect" "testing" @@ -22,7 +23,7 @@ func TestServers(t *testing.T) { } d := &Discovery{httpClient: c} // get servers - s1, err := d.Servers() + s1, err := d.Servers(context.Background()) if err != nil { t.Fatalf("Failed getting servers: %v", err) } @@ -30,7 +31,7 @@ func TestServers(t *testing.T) { // Shutdown the server s.Close() // Test if we get the same cached copy - s2, err := d.Servers() + s2, err := d.Servers(context.Background()) // We should not get an error as the timestamp is not expired if err != nil { t.Fatalf("Got a servers error after shutting down server: %v", err) @@ -42,7 +43,7 @@ func TestServers(t *testing.T) { // Force expired, 1 hour in the past d.ServerList.Timestamp = time.Now().Add(-1 * time.Hour) - s3, err := d.Servers() + s3, err := d.Servers(context.Background()) // Now we expect an error with the cached copy if err == nil { t.Fatalf("Got a servers nil error after shutting down file server and expired") @@ -64,7 +65,7 @@ func TestOrganizations(t *testing.T) { } d := &Discovery{httpClient: c} // get servers - s1, err := d.Organizations() + s1, err := d.Organizations(context.Background()) if err != nil { t.Fatalf("Failed getting organizations: %v", err) } @@ -73,7 +74,7 @@ func TestOrganizations(t *testing.T) { s.Close() // Test if we get the same cached copy // We should not get an error as the timestamp is not zero - s2, err := d.Organizations() + s2, err := d.Organizations(context.Background()) if err != nil { t.Fatalf("Got an organizations error after shutting down file server: %v", err) } |
