From ff70e291c96de23ae4dab20f9c4e9f895eee53d5 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Fri, 6 Jan 2023 14:53:34 +0100 Subject: Refactor: Re-use a HTTP client --- internal/discovery/discovery.go | 20 +++++++---- internal/http/http.go | 75 +++++++++++++++++++++-------------------- internal/oauth/oauth.go | 19 +++++++++-- internal/server/api.go | 10 ++++-- internal/server/base.go | 2 ++ 5 files changed, 79 insertions(+), 47 deletions(-) diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index b2a90cd..41685ac 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -14,6 +14,9 @@ import ( // Discovery is the main structure used for this package. type Discovery struct { + // The httpClient for sending HTTP requests + httpClient *http.Client + // organizations represents the organizations that are returned by the discovery server organizations types.DiscoveryOrganizations @@ -23,12 +26,17 @@ type Discovery struct { var DiscoURL = "https://disco.eduvpn.org/v2/" -// discoFile is a helper function that gets a disco JSON and fills the structure with it +// file is a helper function that gets a disco JSON and fills the structure with it // If it was unsuccessful it returns an error. -func discoFile(jsonFile string, previousVersion uint64, structure interface{}) error { +func (discovery *Discovery) file(jsonFile string, previousVersion uint64, structure interface{}) error { + // No HTTP client present, create one + if discovery.httpClient == nil { + discovery.httpClient = http.NewClient() + } + // Get json data jsonURL := DiscoURL + jsonFile - _, body, err := http.Get(jsonURL) + _, body, err := discovery.httpClient.Get(jsonURL) if err != nil { return err } @@ -36,7 +44,7 @@ func discoFile(jsonFile string, previousVersion uint64, structure interface{}) e // Get signature sigFile := jsonFile + ".minisig" sigURL := DiscoURL + sigFile - _, sigBody, err := http.Get(sigURL) + _, sigBody, err := discovery.httpClient.Get(sigURL) if err != nil { return err } @@ -162,7 +170,7 @@ func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, erro return &discovery.organizations, nil } file := "organization_list.json" - err := discoFile(file, discovery.organizations.Version, &discovery.organizations) + err := discovery.file(file, discovery.organizations.Version, &discovery.organizations) if err != nil { // Return previous with an error return &discovery.organizations, err @@ -178,7 +186,7 @@ func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) { return &discovery.servers, nil } file := "server_list.json" - err := discoFile(file, discovery.servers.Version, &discovery.servers) + err := discovery.file(file, discovery.servers.Version, &discovery.servers) if err != nil { // Return previous with an error return &discovery.servers, err diff --git a/internal/http/http.go b/internal/http/http.go index 81ab822..2512b40 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -20,7 +20,6 @@ type OptionalParams struct { Headers http.Header URLParameters URLParameters Body url.Values - Timeout time.Duration } // ConstructURL creates a URL with the included parameters. @@ -40,16 +39,6 @@ func ConstructURL(baseURL string, params URLParameters) (string, error) { return u.String(), nil } -// Get creates a Get request and returns the headers, body and an error. -func Get(url string) (http.Header, []byte, error) { - return MethodWithOpts(http.MethodGet, url, nil) -} - -// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error. -func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) { - return MethodWithOpts(http.MethodPost, url, opts) -} - // optionalURL ensures that the URL contains the optional parameters // it returns the url (with parameters if success) and an error indicating success. func optionalURL(urlStr string, opts *OptionalParams) (string, error) { @@ -78,18 +67,43 @@ func optionalBodyReader(opts *OptionalParams) io.Reader { return nil } -// ReadLimit denotes the maximum amount of bytes that are read in HTTP responses -// This is used to prevent servers from sending huge amounts of data -// A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems -var ReadLimit int64 = 16 << 20 +// Client is a wrapper around http.Client with some convenience features +// - A default timeout of 5 seconds +// - A read limiter to prevent servers from sending large amounts of data +// - Checking on http code with custom errors +type Client struct { + // Client is the HTTP Client that sends the request + Client *http.Client + // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses + // This is used to prevent servers from sending huge amounts of data + // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems + ReadLimit int64 +} + +// Returns a HTTP client with some default settings +func NewClient() *Client { + // The timeout is 5 seconds by default + c := &http.Client{Timeout: 5 * time.Second} + // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses + // This is used to prevent servers from sending huge amounts of data + // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems + return &Client{Client: c, ReadLimit: 16 << 20} +} + +// Get creates a Get request and returns the headers, body and an error. +func (c *Client) Get(url string) (http.Header, []byte, error) { + return c.Do(http.MethodGet, url, nil) +} + +// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error. +func (c *Client) PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) { + return c.Do(http.MethodPost, url, opts) +} -// MethodWithOpts creates a HTTP request using a method (e.g. GET, POST), an url and optional parameters + +// MethodWithOpts Do send a HTTP request using a method (e.g. GET, POST), an url and optional parameters // It returns the HTTP headers, the body and an error if there is one. -func MethodWithOpts( - method string, - urlStr string, - opts *OptionalParams, -) (http.Header, []byte, error) { +func (c *Client) Do(method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) { // Make sure the url contains all the parameters // This can return an error, // it already has the right error, so we don't wrap it further @@ -99,13 +113,6 @@ func MethodWithOpts( return nil, nil, err } - // Default timeout is 5 seconds - // If a different timeout is given, set it - var timeout time.Duration = 5 - if opts != nil && opts.Timeout > 0 { - timeout = opts.Timeout - } - // Create request object with the body reader generated from the optional arguments req, err := http.NewRequest(method, urlStr, optionalBodyReader(opts)) if err != nil { @@ -113,17 +120,11 @@ func MethodWithOpts( fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0) } - // See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi - req.Close = true - // Make sure the headers contain all the parameters optionalHeaders(req, opts) - // Create a client - c := &http.Client{Timeout: timeout * time.Second} - // Do request - res, err := c.Do(req) + res, err := c.Client.Do(req) if err != nil { return nil, nil, errors.WrapPrefix(err, fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0) @@ -139,11 +140,11 @@ func MethodWithOpts( // However, this is still nice to use because unlike a limitreader, it returns an error if the body is too large // We use this function without a writer so we pass nil // We impose a limit because servers could be malicious and send huge amounts of data - r := http.MaxBytesReader(nil, res.Body, ReadLimit) + r := http.MaxBytesReader(nil, res.Body, c.ReadLimit) body, err := io.ReadAll(r) if err != nil { return res.Header, nil, errors.WrapPrefix(err, - fmt.Sprintf("failed HTTP request with method: %s, url: %s and max bytes size: %v", method, urlStr, ReadLimit), 0) + fmt.Sprintf("failed HTTP request with method: %s, url: %s and max bytes size: %v", method, urlStr, c.ReadLimit), 0) } if res.StatusCode < 200 || res.StatusCode > 299 { return res.Header, body, errors.Wrap(&StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}, 0) diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 6d21c82..fd8a63d 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -75,6 +75,9 @@ func genVerifier() (string, error) { // OAuth defines the main structure for this package. type OAuth struct { + // The HTTP client that is used + httpClient *httpw.Client + // ISS indicates the issuer identifier of the authorization server as defined in RFC 9207 ISS string `json:"iss"` @@ -235,7 +238,9 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { } opts := &httpw.OptionalParams{Headers: h, Body: data} now := time.Now() - _, body, err := httpw.PostWithOpts(u, opts) + + // We are sure that we have a http client because we have initialized it when starting the exchange + _, body, err := oauth.httpClient.PostWithOpts(u, opts) if err != nil { return err } @@ -278,7 +283,13 @@ func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) } opts := &httpw.OptionalParams{Headers: h, Body: data} now := time.Now() - _, body, err := httpw.PostWithOpts(u, opts) + + // Test if we have a http client and if not recreate one + if oauth.httpClient == nil { + oauth.httpClient = httpw.NewClient() + } + + _, body, err := oauth.httpClient.PostWithOpts(u, opts) if err != nil { return nil, time.Time{}, err } @@ -481,6 +492,10 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s // Exchange starts the OAuth exchange by getting the tokens with the redirect callback // If it was unsuccessful it returns an error. func (oauth *OAuth) Exchange() error { + // If there is no HTTP client defined, create a new one + if oauth.httpClient == nil { + oauth.httpClient = httpw.NewClient() + } return oauth.tokensWithCallback() } diff --git a/internal/server/api.go b/internal/server/api.go index 2ce3db5..1ac3164 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -19,7 +19,8 @@ func APIGetEndpoints(baseURL string) (*Endpoints, error) { } u.Path = path.Join(u.Path, "/.well-known/vpn-user-portal") - _, body, err := httpw.Get(u.String()) + c := httpw.NewClient() + _, body, err := c.Get(u.String()) if err != nil { return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } @@ -68,7 +69,12 @@ func apiAuthorized( } else { opts.Headers = http.Header{key: {val}} } - return httpw.MethodWithOpts(method, u.String(), opts) + + // Create a client if it doesn't exist + if b.httpClient == nil { + b.httpClient = httpw.NewClient() + } + return b.httpClient.Do(method, u.String(), opts) } func apiAuthorizedRetry( diff --git a/internal/server/base.go b/internal/server/base.go index dd15aff..55fff09 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -2,6 +2,7 @@ package server import ( "time" + "github.com/eduvpn/eduvpn-common/internal/http" ) // Base is the base type for servers. @@ -14,6 +15,7 @@ type Base struct { StartTime time.Time `json:"start_time"` EndTime time.Time `json:"expire_time"` Type string `json:"server_type"` + httpClient *http.Client } func (b *Base) InitializeEndpoints() error { -- cgit v1.2.3