diff options
| -rw-r--r-- | internal/discovery/discovery.go | 96 | ||||
| -rw-r--r-- | internal/discovery/discovery_test.go | 28 | ||||
| -rw-r--r-- | internal/http/http.go | 5 |
3 files changed, 102 insertions, 27 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index ad083b3..30ca801 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -4,10 +4,12 @@ package discovery import ( "context" "encoding/json" + "errors" + "net/http" "fmt" "time" - "github.com/eduvpn/eduvpn-common/internal/http" + httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/levenshtein" "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/internal/verify" @@ -26,6 +28,8 @@ type Organizations struct { // Timestamp is the timestamp that is internally used by the Go library to keep track // of when the organizations were last updated Timestamp time.Time `json:"go_timestamp"` + // UpdateHeader is the result of the "Last-Modified" header + UpdateHeader time.Time `json:"go_update_header"` } // Organization is a single discovery Organization @@ -53,6 +57,8 @@ type Servers struct { // Timestamp is a timestamp that is internally used by the Go library to keek track // of when the servers were last updated Timestamp time.Time `json:"go_timestamp"` + // UpdateHeader is the result of the "Last-Modified" header + UpdateHeader time.Time `json:"go_update_header"` } // Server is a single discovery server @@ -77,7 +83,7 @@ func (s *Server) Score(search string) int { // Discovery is the main structure used for this package. type Discovery struct { // The httpClient for sending HTTP requests - httpClient *http.Client + httpClient *httpw.Client // Organizations represents the organizations that are returned by the discovery server OrganizationList Organizations `json:"organizations"` @@ -91,31 +97,55 @@ var DiscoURL = "https://disco.eduvpn.org/v2/" // file is a helper function that gets a disco JSON and fills the structure with it // If it was unsuccessful it returns an error. -func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, structure interface{}) error { +func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousVersion uint64, last time.Time, structure interface{}) (time.Time, error) { + var newUpdate time.Time // No HTTP client present, create one if discovery.httpClient == nil { - discovery.httpClient = http.NewClient(nil) + discovery.httpClient = httpw.NewClient(nil) } // Get json data - jsonURL, err := http.JoinURLPath(DiscoURL, jsonFile) + jsonURL, err := httpw.JoinURLPath(DiscoURL, jsonFile) if err != nil { - return err + return newUpdate, err } - _, body, err := discovery.httpClient.Get(ctx, jsonURL) + + var opts *httpw.OptionalParams + if !last.IsZero() { + header := http.Header{ + "If-Modified-Since": []string{last.Format(http.TimeFormat)}, + } + opts = &httpw.OptionalParams{ + Headers: header, + } + } + h, body, err := discovery.httpClient.Do(ctx, "GET", jsonURL, opts) if err != nil { - return err + return newUpdate, err + } + + lms := h.Get("Last-Modified") + if lms != "" { + lm, err := http.ParseTime(lms) + if err != nil { + log.Logger.Warningf("failed to parse 'Last-Modified' header: %v", err) + } else { + newUpdate = lm + log.Logger.Debugf("got 'Last-Modified' header: %v", lm) + } + } else { + log.Logger.Warningf("no 'Last-Modified' header found") } // Get signature sigFile := jsonFile + ".minisig" - sigURL, err := http.JoinURLPath(DiscoURL, sigFile) + sigURL, err := httpw.JoinURLPath(DiscoURL, sigFile) if err != nil { - return err + return newUpdate, err } _, sigBody, err := discovery.httpClient.Get(ctx, sigURL) if err != nil { - return err + return newUpdate, err } // Verify signature @@ -130,15 +160,15 @@ func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousV ) if !ok || err != nil { - return err + return newUpdate, err } // Parse JSON to extract version and list if err = json.Unmarshal(body, structure); err != nil { - return fmt.Errorf("failed parsing discovery file: '%s' from the server with error: %w", jsonFile, err) + return newUpdate, fmt.Errorf("failed parsing discovery file: '%s' from the server with error: %w", jsonFile, err) } - return nil + return newUpdate, nil } // MarkOrganizationsExpired marks the organizations as expired @@ -163,6 +193,9 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool { if discovery.OrganizationList.Timestamp.IsZero() { return true } + if discovery.OrganizationList.UpdateHeader.IsZero() { + return true + } // 4 hour since the last update upd := discovery.OrganizationList.Timestamp.Add(4 * time.Hour) return !time.Now().Before(upd) @@ -253,6 +286,9 @@ func (discovery *Discovery) DetermineServersUpdate() bool { if discovery.ServerList.Timestamp.IsZero() { return true } + if discovery.ServerList.UpdateHeader.IsZero() { + return true + } // 1 hour from the last update upd := discovery.ServerList.Timestamp.Add(1 * time.Hour) return !time.Now().Before(upd) @@ -299,9 +335,18 @@ func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations, } file := "organization_list.json" var jsonDecode Organizations - err := discovery.file(ctx, file, discovery.OrganizationList.Version, &jsonDecode) + update, err := discovery.file(ctx, file, discovery.OrganizationList.Version, discovery.OrganizationList.UpdateHeader, &jsonDecode) if err != nil { - log.Logger.Warningf("failed to get fresh organizations: %v", err) + statErr := &httpw.StatusError{} + if errors.As(err, &statErr) { + if statErr.Status != 304 { + log.Logger.Warningf("failed to get fresh organizations: %v", err) + } else { + discovery.OrganizationList.Timestamp = time.Now() + log.Logger.Debugf("got 304 for discovery, organization_list.json not modified") + err = nil + } + } // Return previous with an error orgs, perr := discovery.previousOrganizations() if perr != nil { @@ -315,6 +360,9 @@ func (discovery *Discovery) Organizations(ctx context.Context) (*Organizations, discovery.OrganizationList = jsonDecode } discovery.OrganizationList.Timestamp = time.Now() + if !update.IsZero() { + discovery.OrganizationList.UpdateHeader = update + } return &discovery.OrganizationList, true, nil } @@ -327,9 +375,18 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error) } file := "server_list.json" var jsonDecode Servers - err := discovery.file(ctx, file, discovery.ServerList.Version, &jsonDecode) + update, err := discovery.file(ctx, file, discovery.ServerList.Version, discovery.ServerList.UpdateHeader, &jsonDecode) if err != nil { - log.Logger.Warningf("failed to get fresh servers: %v", err) + statErr := &httpw.StatusError{} + if errors.As(err, &statErr) { + if statErr.Status != 304 { + log.Logger.Warningf("failed to get fresh servers: %v", err) + } else { + discovery.ServerList.Timestamp = time.Now() + log.Logger.Debugf("got 304 for discovery, server_list.json not modified") + err = nil + } + } // Return previous with an error srvs, perr := discovery.previousServers() if perr != nil { @@ -343,5 +400,8 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error) discovery.ServerList = jsonDecode } discovery.ServerList.Timestamp = time.Now() + if !update.IsZero() { + discovery.ServerList.UpdateHeader = update + } return &discovery.ServerList, true, nil } diff --git a/internal/discovery/discovery_test.go b/internal/discovery/discovery_test.go index 1101bb2..5672b9f 100644 --- a/internal/discovery/discovery_test.go +++ b/internal/discovery/discovery_test.go @@ -50,12 +50,22 @@ func TestServers(t *testing.T) { }, SupportContact: []string{"mailto:test@example.org"}, } + // conditional requests: this should not be fetched fresh + _, fresh, err = d.Servers(context.Background()) + if fresh { + t.Fatalf("Obtained the server list fresh with conditional requests") + } + if err != nil { + t.Fatalf("Failed getting servers after inserting a mock entry: %v", err) + } + // mock conditional requests + d.ServerList.UpdateHeader = time.Time{} s1, fresh, err := d.Servers(context.Background()) if !fresh { - t.Fatalf("Did not obtain the server list fresh after inserting a mock entry and faking expiry") + t.Fatalf("Did not obtain the server list fresh with mocked conditional request and mocked entry") } if err != nil { - t.Fatalf("Failed getting servers after inserting a mock entry: %v", err) + t.Fatalf("Failed getting servers after inserting a mock entry and mocking conditional request: %v", err) } if kws := s1.List[len(s1.List)-1].KeywordList; kws != nil { t.Fatalf("KeywordList is not nil when getting a fresh server list after inserting a mock entry: %v", kws) @@ -134,12 +144,22 @@ func TestOrganizations(t *testing.T) { "en": "test bla", }, } + + _, fresh, err = d.Organizations(context.Background()) + if fresh { + t.Fatalf("Obtained the organization list fresh with conditional requests") + } + if err != nil { + t.Fatalf("Failed getting organizations after inserting a mock entry: %v", err) + } + // mock conditional requests + d.OrganizationList.UpdateHeader = time.Time{} s1, fresh, err := d.Organizations(context.Background()) if !fresh { - t.Fatalf("Did not obtain the organization list fresh after inserting a mock entry and faking expiry") + t.Fatalf("Did not obtain the organization list fresh after inserting a mock entry, faking expiry and mocking conditional request") } if err != nil { - t.Fatalf("Failed getting organizations after inserting a mock entry: %v", err) + t.Fatalf("Failed getting organizations after inserting a mock entry and faking conditional request: %v", err) } if kws := s1.List[len(s1.List)-1].KeywordList; kws != nil { t.Fatalf("KeywordList is not nil when getting a fresh organization list after inserting a mock entry: %v", kws) diff --git a/internal/http/http.go b/internal/http/http.go index 09f1953..196998b 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -164,11 +164,6 @@ func (c *Client) Get(ctx context.Context, url string) (http.Header, []byte, erro return c.Do(ctx, http.MethodGet, url, nil) } -// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error. -func (c *Client) PostWithOpts(ctx context.Context, url string, opts *OptionalParams) (http.Header, []byte, error) { - return c.Do(ctx, http.MethodPost, url, opts) -} - // Do sends a HTTP request using a method (e.g. GET, POST), an url and optional parameters // It returns the HTTP headers, the body and an error if there is one. func (c *Client) Do(ctx context.Context, method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) { |
