summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/discovery/discovery.go96
-rw-r--r--internal/discovery/discovery_test.go28
-rw-r--r--internal/http/http.go5
3 files changed, 102 insertions, 27 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index ad083b3..30ca801 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -4,10 +4,12 @@ package discovery
import (
"context"
"encoding/json"
+ "errors"
+ "net/http"
"fmt"
"time"
- "github.com/eduvpn/eduvpn-common/internal/http"
+ httpw "github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/levenshtein"
"github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/internal/verify"
@@ -26,6 +28,8 @@ type Organizations struct {
// Timestamp is the timestamp that is internally used by the Go library to keep track
// of when the organizations were last updated
Timestamp time.Time `json:"go_timestamp"`
+ // UpdateHeader is the result of the "Last-Modified" header
+ UpdateHeader time.Time `json:"go_update_header"`
}
// Organization is a single discovery Organization
@@ -53,6 +57,8 @@ type Servers struct {
// Timestamp is a timestamp that is internally used by the Go library to keek track
// of when the servers were last updated
Timestamp time.Time `json:"go_timestamp"`
+ // UpdateHeader is the result of the "Last-Modified" header
+ UpdateHeader time.Time `json:"go_update_header"`
}
// Server is a single discovery server
@@ -77,7 +83,7 @@ func (s *Server) Score(search string) int {
// Discovery is the main structure used for this package.
type Discovery struct {
// The httpClient for sending HTTP requests
- httpClient *http.Client
+ httpClient *httpw.Client
// Organizations represents the organizations that are returned by the discovery server
OrganizationList Organizations `json:"organizations"`
@@ -91,31 +97,55 @@ var DiscoURL = "https://disco.eduvpn.org/v2/"
// file is a helper function that gets a disco JSON and fills the structure with it
// If it was unsuccessful it returns an error.
-func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, structure interface{}) error {
+func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, last time.Time, structure interface{}) (time.Time, error) {
+ var newUpdate time.Time
// No HTTP client present, create one
if discovery.httpClient == nil {
- discovery.httpClient = http.NewClient(nil)
+ discovery.httpClient = httpw.NewClient(nil)
}
// Get json data
- jsonURL, err := http.JoinURLPath(DiscoURL, jsonFile)
+ jsonURL, err := httpw.JoinURLPath(DiscoURL, jsonFile)
if err != nil {
- return err
+ return newUpdate, err
}
- _, body, err := discovery.httpClient.Get(ctx, jsonURL)
+
+ var opts *httpw.OptionalParams
+ if !last.IsZero() {
+ header := http.Header{
+ "If-Modified-Since": []string{last.Format(http.TimeFormat)},
+ }
+ opts = &httpw.OptionalParams{
+ Headers: header,
+ }
+ }
+ h, body, err := discovery.httpClient.Do(ctx, "GET", jsonURL, opts)
if err != nil {
- return err
+ return newUpdate, err
+ }
+
+ lms := h.Get("Last-Modified")
+ if lms != "" {
+ lm, err := http.ParseTime(lms)
+ if err != nil {
+ log.Logger.Warningf("failed to parse 'Last-Modified' header: %v", err)
+ } else {
+ newUpdate = lm
+ log.Logger.Debugf("got 'Last-Modified' header: %v", lm)
+ }
+ } else {
+ log.Logger.Warningf("no 'Last-Modified' header found")
}
// Get signature
sigFile := jsonFile + ".minisig"
- sigURL, err := http.JoinURLPath(DiscoURL, sigFile)
+ sigURL, err := httpw.JoinURLPath(DiscoURL, sigFile)
if err != nil {
- return err
+ return newUpdate, err
}
_, sigBody, err := discovery.httpClient.Get(ctx, sigURL)
if err != nil {
- return err
+ return newUpdate, err
}
// Verify signature
@@ -130,15 +160,15 @@ func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousV
)
if !ok || err != nil {
- return err
+ return newUpdate, err
}
// Parse JSON to extract version and list
if err = json.Unmarshal(body, structure); err != nil {
- return fmt.Errorf("failed parsing discovery file: '%s' from the server with error: %w", jsonFile, err)
+ return newUpdate, fmt.Errorf("failed parsing discovery file: '%s' from the server with error: %w", jsonFile, err)
}
- return nil
+ return newUpdate, nil
}
// MarkOrganizationsExpired marks the organizations as expired
@@ -163,6 +193,9 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
if discovery.OrganizationList.Timestamp.IsZero() {
return true
}
+ if discovery.OrganizationList.UpdateHeader.IsZero() {
+ return true
+ }
// 4 hour since the last update
upd := discovery.OrganizationList.Timestamp.Add(4 * time.Hour)
return !time.Now().Before(upd)
@@ -253,6 +286,9 @@ func (discovery *Discovery) DetermineServersUpdate() bool {
if discovery.ServerList.Timestamp.IsZero() {
return true
}
+ if discovery.ServerList.UpdateHeader.IsZero() {
+ return true
+ }
// 1 hour from the last update
upd := discovery.ServerList.Timestamp.Add(1 * time.Hour)
return !time.Now().Before(upd)
@@ -299,9 +335,18 @@ func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations,
}
file := "organization_list.json"
var jsonDecode Organizations
- err := discovery.file(ctx, file, discovery.OrganizationList.Version, &jsonDecode)
+ update, err := discovery.file(ctx, file, discovery.OrganizationList.Version, discovery.OrganizationList.UpdateHeader, &jsonDecode)
if err != nil {
- log.Logger.Warningf("failed to get fresh organizations: %v", err)
+ statErr := &httpw.StatusError{}
+ if errors.As(err, &statErr) {
+ if statErr.Status != 304 {
+ log.Logger.Warningf("failed to get fresh organizations: %v", err)
+ } else {
+ discovery.OrganizationList.Timestamp = time.Now()
+ log.Logger.Debugf("got 304 for discovery, organization_list.json not modified")
+ err = nil
+ }
+ }
// Return previous with an error
orgs, perr := discovery.previousOrganizations()
if perr != nil {
@@ -315,6 +360,9 @@ func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations,
discovery.OrganizationList = jsonDecode
}
discovery.OrganizationList.Timestamp = time.Now()
+ if !update.IsZero() {
+ discovery.OrganizationList.UpdateHeader = update
+ }
return &discovery.OrganizationList, true, nil
}
@@ -327,9 +375,18 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error)
}
file := "server_list.json"
var jsonDecode Servers
- err := discovery.file(ctx, file, discovery.ServerList.Version, &jsonDecode)
+ update, err := discovery.file(ctx, file, discovery.ServerList.Version, discovery.ServerList.UpdateHeader, &jsonDecode)
if err != nil {
- log.Logger.Warningf("failed to get fresh servers: %v", err)
+ statErr := &httpw.StatusError{}
+ if errors.As(err, &statErr) {
+ if statErr.Status != 304 {
+ log.Logger.Warningf("failed to get fresh servers: %v", err)
+ } else {
+ discovery.ServerList.Timestamp = time.Now()
+ log.Logger.Debugf("got 304 for discovery, server_list.json not modified")
+ err = nil
+ }
+ }
// Return previous with an error
srvs, perr := discovery.previousServers()
if perr != nil {
@@ -343,5 +400,8 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error)
discovery.ServerList = jsonDecode
}
discovery.ServerList.Timestamp = time.Now()
+ if !update.IsZero() {
+ discovery.ServerList.UpdateHeader = update
+ }
return &discovery.ServerList, true, nil
}
diff --git a/internal/discovery/discovery_test.go b/internal/discovery/discovery_test.go
index 1101bb2..5672b9f 100644
--- a/internal/discovery/discovery_test.go
+++ b/internal/discovery/discovery_test.go
@@ -50,12 +50,22 @@ func TestServers(t *testing.T) {
},
SupportContact: []string{"mailto:test@example.org"},
}
+ // conditional requests: this should not be fetched fresh
+ _, fresh, err = d.Servers(context.Background())
+ if fresh {
+ t.Fatalf("Obtained the server list fresh with conditional requests")
+ }
+ if err != nil {
+ t.Fatalf("Failed getting servers after inserting a mock entry: %v", err)
+ }
+ // mock conditional requests
+ d.ServerList.UpdateHeader = time.Time{}
s1, fresh, err := d.Servers(context.Background())
if !fresh {
- t.Fatalf("Did not obtain the server list fresh after inserting a mock entry and faking expiry")
+ t.Fatalf("Did not obtain the server list fresh with mocked conditional request and mocked entry")
}
if err != nil {
- t.Fatalf("Failed getting servers after inserting a mock entry: %v", err)
+ t.Fatalf("Failed getting servers after inserting a mock entry and mocking conditional request: %v", err)
}
if kws := s1.List[len(s1.List)-1].KeywordList; kws != nil {
t.Fatalf("KeywordList is not nil when getting a fresh server list after inserting a mock entry: %v", kws)
@@ -134,12 +144,22 @@ func TestOrganizations(t *testing.T) {
"en": "test bla",
},
}
+
+ _, fresh, err = d.Organizations(context.Background())
+ if fresh {
+ t.Fatalf("Obtained the organization list fresh with conditional requests")
+ }
+ if err != nil {
+ t.Fatalf("Failed getting organizations after inserting a mock entry: %v", err)
+ }
+ // mock conditional requests
+ d.OrganizationList.UpdateHeader = time.Time{}
s1, fresh, err := d.Organizations(context.Background())
if !fresh {
- t.Fatalf("Did not obtain the organization list fresh after inserting a mock entry and faking expiry")
+ t.Fatalf("Did not obtain the organization list fresh after inserting a mock entry, faking expiry and mocking conditional request")
}
if err != nil {
- t.Fatalf("Failed getting organizations after inserting a mock entry: %v", err)
+ t.Fatalf("Failed getting organizations after inserting a mock entry and faking conditional request: %v", err)
}
if kws := s1.List[len(s1.List)-1].KeywordList; kws != nil {
t.Fatalf("KeywordList is not nil when getting a fresh organization list after inserting a mock entry: %v", kws)
diff --git a/internal/http/http.go b/internal/http/http.go
index 09f1953..196998b 100644
--- a/internal/http/http.go
+++ b/internal/http/http.go
@@ -164,11 +164,6 @@ func (c *Client) Get(ctx context.Context, url string) (http.Header, []byte, erro
return c.Do(ctx, http.MethodGet, url, nil)
}
-// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error.
-func (c *Client) PostWithOpts(ctx context.Context, url string, opts *OptionalParams) (http.Header, []byte, error) {
- return c.Do(ctx, http.MethodPost, url, opts)
-}
-
// Do sends a HTTP request using a method (e.g. GET, POST), an url and optional parameters
// It returns the HTTP headers, the body and an error if there is one.
func (c *Client) Do(ctx context.Context, method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) {