summaryrefslogtreecommitdiff
path: root/internal/discovery
diff options
context:
space:
mode:
Diffstat (limited to 'internal/discovery')
-rw-r--r--internal/discovery/discovery.go15
-rw-r--r--internal/discovery/discovery_test.go11
2 files changed, 14 insertions, 12 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index ae7a307..06548f9 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -2,6 +2,7 @@
package discovery
import (
+ "context"
"encoding/json"
"fmt"
"time"
@@ -31,7 +32,7 @@ var DiscoURL = "https://disco.eduvpn.org/v2/"
// file is a helper function that gets a disco JSON and fills the structure with it
// If it was unsuccessful it returns an error.
-func (discovery *Discovery) file(jsonFile string, previousVersion uint64, structure interface{}) error {
+func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, structure interface{}) error {
// No HTTP client present, create one
if discovery.httpClient == nil {
discovery.httpClient = http.NewClient()
@@ -42,7 +43,7 @@ func (discovery *Discovery) file(jsonFile string, previousVersion uint64, struct
if err != nil {
return err
}
- _, body, err := discovery.httpClient.Get(jsonURL)
+ _, body, err := discovery.httpClient.Get(ctx, jsonURL)
if err != nil {
return err
}
@@ -53,7 +54,7 @@ func (discovery *Discovery) file(jsonFile string, previousVersion uint64, struct
if err != nil {
return err
}
- _, sigBody, err := discovery.httpClient.Get(sigURL)
+ _, sigBody, err := discovery.httpClient.Get(ctx, sigURL)
if err != nil {
return err
}
@@ -212,12 +213,12 @@ func (discovery *Discovery) previousServers() (*discotypes.Servers, error) {
// Organizations returns the discovery organizations
// If there was an error, a cached copy is returned if available.
-func (discovery *Discovery) Organizations() (*discotypes.Organizations, error) {
+func (discovery *Discovery) Organizations(ctx context.Context) (*discotypes.Organizations, error) {
if !discovery.DetermineOrganizationsUpdate() {
return &discovery.OrganizationList, nil
}
file := "organization_list.json"
- err := discovery.file(file, discovery.OrganizationList.Version, &discovery.OrganizationList)
+ err := discovery.file(ctx, file, discovery.OrganizationList.Version, &discovery.OrganizationList)
if err != nil {
// Return previous with an error
// TODO: Log here if we fail to get previous
@@ -230,12 +231,12 @@ func (discovery *Discovery) Organizations() (*discotypes.Organizations, error) {
// Servers returns the discovery servers
// If there was an error, a cached copy is returned if available.
-func (discovery *Discovery) Servers() (*discotypes.Servers, error) {
+func (discovery *Discovery) Servers(ctx context.Context) (*discotypes.Servers, error) {
if !discovery.DetermineServersUpdate() {
return &discovery.ServerList, nil
}
file := "server_list.json"
- err := discovery.file(file, discovery.ServerList.Version, &discovery.ServerList)
+ err := discovery.file(ctx, file, discovery.ServerList.Version, &discovery.ServerList)
if err != nil {
// Return previous with an error
// TODO: Log here if we fail to get previous
diff --git a/internal/discovery/discovery_test.go b/internal/discovery/discovery_test.go
index 93ab51e..317aa50 100644
--- a/internal/discovery/discovery_test.go
+++ b/internal/discovery/discovery_test.go
@@ -1,6 +1,7 @@
package discovery
import (
+ "context"
"net/http"
"reflect"
"testing"
@@ -22,7 +23,7 @@ func TestServers(t *testing.T) {
}
d := &Discovery{httpClient: c}
// get servers
- s1, err := d.Servers()
+ s1, err := d.Servers(context.Background())
if err != nil {
t.Fatalf("Failed getting servers: %v", err)
}
@@ -30,7 +31,7 @@ func TestServers(t *testing.T) {
// Shutdown the server
s.Close()
// Test if we get the same cached copy
- s2, err := d.Servers()
+ s2, err := d.Servers(context.Background())
// We should not get an error as the timestamp is not expired
if err != nil {
t.Fatalf("Got a servers error after shutting down server: %v", err)
@@ -42,7 +43,7 @@ func TestServers(t *testing.T) {
// Force expired, 1 hour in the past
d.ServerList.Timestamp = time.Now().Add(-1 * time.Hour)
- s3, err := d.Servers()
+ s3, err := d.Servers(context.Background())
// Now we expect an error with the cached copy
if err == nil {
t.Fatalf("Got a servers nil error after shutting down file server and expired")
@@ -64,7 +65,7 @@ func TestOrganizations(t *testing.T) {
}
d := &Discovery{httpClient: c}
// get servers
- s1, err := d.Organizations()
+ s1, err := d.Organizations(context.Background())
if err != nil {
t.Fatalf("Failed getting organizations: %v", err)
}
@@ -73,7 +74,7 @@ func TestOrganizations(t *testing.T) {
s.Close()
// Test if we get the same cached copy
// We should not get an error as the timestamp is not zero
- s2, err := d.Organizations()
+ s2, err := d.Organizations(context.Background())
if err != nil {
t.Fatalf("Got an organizations error after shutting down file server: %v", err)
}