summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-01-06 14:53:34 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2023-01-06 14:59:10 +0100
commitff70e291c96de23ae4dab20f9c4e9f895eee53d5 (patch)
treefc46d7eced911fb083fa3c58bcd6d173e3afa82d /internal
parent965e36a91ca4eb6614ee71d0352ef42b465eb54f (diff)
Refactor: Re-use a HTTP client
Diffstat (limited to 'internal')
-rw-r--r--internal/discovery/discovery.go20
-rw-r--r--internal/http/http.go75
-rw-r--r--internal/oauth/oauth.go19
-rw-r--r--internal/server/api.go10
-rw-r--r--internal/server/base.go2
5 files changed, 79 insertions, 47 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index b2a90cd..41685ac 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -14,6 +14,9 @@ import (
// Discovery is the main structure used for this package.
type Discovery struct {
+ // The httpClient for sending HTTP requests
+ httpClient *http.Client
+
// organizations represents the organizations that are returned by the discovery server
organizations types.DiscoveryOrganizations
@@ -23,12 +26,17 @@ type Discovery struct {
var DiscoURL = "https://disco.eduvpn.org/v2/"
-// discoFile is a helper function that gets a disco JSON and fills the structure with it
+// file is a helper function that gets a disco JSON and fills the structure with it
// If it was unsuccessful it returns an error.
-func discoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
+func (discovery *Discovery) file(jsonFile string, previousVersion uint64, structure interface{}) error {
+ // No HTTP client present, create one
+ if discovery.httpClient == nil {
+ discovery.httpClient = http.NewClient()
+ }
+
// Get json data
jsonURL := DiscoURL + jsonFile
- _, body, err := http.Get(jsonURL)
+ _, body, err := discovery.httpClient.Get(jsonURL)
if err != nil {
return err
}
@@ -36,7 +44,7 @@ func discoFile(jsonFile string, previousVersion uint64, structure interface{}) e
// Get signature
sigFile := jsonFile + ".minisig"
sigURL := DiscoURL + sigFile
- _, sigBody, err := http.Get(sigURL)
+ _, sigBody, err := discovery.httpClient.Get(sigURL)
if err != nil {
return err
}
@@ -162,7 +170,7 @@ func (discovery *Discovery) Organizations() (*types.DiscoveryOrganizations, erro
return &discovery.organizations, nil
}
file := "organization_list.json"
- err := discoFile(file, discovery.organizations.Version, &discovery.organizations)
+ err := discovery.file(file, discovery.organizations.Version, &discovery.organizations)
if err != nil {
// Return previous with an error
return &discovery.organizations, err
@@ -178,7 +186,7 @@ func (discovery *Discovery) Servers() (*types.DiscoveryServers, error) {
return &discovery.servers, nil
}
file := "server_list.json"
- err := discoFile(file, discovery.servers.Version, &discovery.servers)
+ err := discovery.file(file, discovery.servers.Version, &discovery.servers)
if err != nil {
// Return previous with an error
return &discovery.servers, err
diff --git a/internal/http/http.go b/internal/http/http.go
index 81ab822..2512b40 100644
--- a/internal/http/http.go
+++ b/internal/http/http.go
@@ -20,7 +20,6 @@ type OptionalParams struct {
Headers http.Header
URLParameters URLParameters
Body url.Values
- Timeout time.Duration
}
// ConstructURL creates a URL with the included parameters.
@@ -40,16 +39,6 @@ func ConstructURL(baseURL string, params URLParameters) (string, error) {
return u.String(), nil
}
-// Get creates a Get request and returns the headers, body and an error.
-func Get(url string) (http.Header, []byte, error) {
- return MethodWithOpts(http.MethodGet, url, nil)
-}
-
-// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error.
-func PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) {
- return MethodWithOpts(http.MethodPost, url, opts)
-}
-
// optionalURL ensures that the URL contains the optional parameters
// it returns the url (with parameters if success) and an error indicating success.
func optionalURL(urlStr string, opts *OptionalParams) (string, error) {
@@ -78,18 +67,43 @@ func optionalBodyReader(opts *OptionalParams) io.Reader {
return nil
}
-// ReadLimit denotes the maximum amount of bytes that are read in HTTP responses
-// This is used to prevent servers from sending huge amounts of data
-// A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems
-var ReadLimit int64 = 16 << 20
+// Client is a wrapper around http.Client with some convenience features
+// - A default timeout of 5 seconds
+// - A read limiter to prevent servers from sending large amounts of data
+// - Checking on http code with custom errors
+type Client struct {
+ // Client is the HTTP Client that sends the request
+ Client *http.Client
+ // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses
+ // This is used to prevent servers from sending huge amounts of data
+ // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems
+ ReadLimit int64
+}
+
+// Returns a HTTP client with some default settings
+func NewClient() *Client {
+ // The timeout is 5 seconds by default
+ c := &http.Client{Timeout: 5 * time.Second}
+ // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses
+ // This is used to prevent servers from sending huge amounts of data
+ // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems
+ return &Client{Client: c, ReadLimit: 16 << 20}
+}
+
+// Get creates a Get request and returns the headers, body and an error.
+func (c *Client) Get(url string) (http.Header, []byte, error) {
+ return c.Do(http.MethodGet, url, nil)
+}
+
+// PostWithOpts creates a Post request with optional parameters and returns the headers, body and an error.
+func (c *Client) PostWithOpts(url string, opts *OptionalParams) (http.Header, []byte, error) {
+ return c.Do(http.MethodPost, url, opts)
+}
-// MethodWithOpts creates a HTTP request using a method (e.g. GET, POST), an url and optional parameters
+
+// MethodWithOpts Do send a HTTP request using a method (e.g. GET, POST), an url and optional parameters
// It returns the HTTP headers, the body and an error if there is one.
-func MethodWithOpts(
- method string,
- urlStr string,
- opts *OptionalParams,
-) (http.Header, []byte, error) {
+func (c *Client) Do(method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) {
// Make sure the url contains all the parameters
// This can return an error,
// it already has the right error, so we don't wrap it further
@@ -99,13 +113,6 @@ func MethodWithOpts(
return nil, nil, err
}
- // Default timeout is 5 seconds
- // If a different timeout is given, set it
- var timeout time.Duration = 5
- if opts != nil && opts.Timeout > 0 {
- timeout = opts.Timeout
- }
-
// Create request object with the body reader generated from the optional arguments
req, err := http.NewRequest(method, urlStr, optionalBodyReader(opts))
if err != nil {
@@ -113,17 +120,11 @@ func MethodWithOpts(
fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0)
}
- // See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi
- req.Close = true
-
// Make sure the headers contain all the parameters
optionalHeaders(req, opts)
- // Create a client
- c := &http.Client{Timeout: timeout * time.Second}
-
// Do request
- res, err := c.Do(req)
+ res, err := c.Client.Do(req)
if err != nil {
return nil, nil, errors.WrapPrefix(err,
fmt.Sprintf("failed HTTP request with method %s and url %s", method, urlStr), 0)
@@ -139,11 +140,11 @@ func MethodWithOpts(
// However, this is still nice to use because unlike a limitreader, it returns an error if the body is too large
// We use this function without a writer so we pass nil
// We impose a limit because servers could be malicious and send huge amounts of data
- r := http.MaxBytesReader(nil, res.Body, ReadLimit)
+ r := http.MaxBytesReader(nil, res.Body, c.ReadLimit)
body, err := io.ReadAll(r)
if err != nil {
return res.Header, nil, errors.WrapPrefix(err,
- fmt.Sprintf("failed HTTP request with method: %s, url: %s and max bytes size: %v", method, urlStr, ReadLimit), 0)
+ fmt.Sprintf("failed HTTP request with method: %s, url: %s and max bytes size: %v", method, urlStr, c.ReadLimit), 0)
}
if res.StatusCode < 200 || res.StatusCode > 299 {
return res.Header, body, errors.Wrap(&StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}, 0)
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 6d21c82..fd8a63d 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -75,6 +75,9 @@ func genVerifier() (string, error) {
// OAuth defines the main structure for this package.
type OAuth struct {
+ // The HTTP client that is used
+ httpClient *httpw.Client
+
// ISS indicates the issuer identifier of the authorization server as defined in RFC 9207
ISS string `json:"iss"`
@@ -235,7 +238,9 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
}
opts := &httpw.OptionalParams{Headers: h, Body: data}
now := time.Now()
- _, body, err := httpw.PostWithOpts(u, opts)
+
+ // We are sure that we have a http client because we have initialized it when starting the exchange
+ _, body, err := oauth.httpClient.PostWithOpts(u, opts)
if err != nil {
return err
}
@@ -278,7 +283,13 @@ func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error)
}
opts := &httpw.OptionalParams{Headers: h, Body: data}
now := time.Now()
- _, body, err := httpw.PostWithOpts(u, opts)
+
+ // Test if we have a http client and if not recreate one
+ if oauth.httpClient == nil {
+ oauth.httpClient = httpw.NewClient()
+ }
+
+ _, body, err := oauth.httpClient.PostWithOpts(u, opts)
if err != nil {
return nil, time.Time{}, err
}
@@ -481,6 +492,10 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
// If it was unsuccessful it returns an error.
func (oauth *OAuth) Exchange() error {
+ // If there is no HTTP client defined, create a new one
+ if oauth.httpClient == nil {
+ oauth.httpClient = httpw.NewClient()
+ }
return oauth.tokensWithCallback()
}
diff --git a/internal/server/api.go b/internal/server/api.go
index 2ce3db5..1ac3164 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -19,7 +19,8 @@ func APIGetEndpoints(baseURL string) (*Endpoints, error) {
}
u.Path = path.Join(u.Path, "/.well-known/vpn-user-portal")
- _, body, err := httpw.Get(u.String())
+ c := httpw.NewClient()
+ _, body, err := c.Get(u.String())
if err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
@@ -68,7 +69,12 @@ func apiAuthorized(
} else {
opts.Headers = http.Header{key: {val}}
}
- return httpw.MethodWithOpts(method, u.String(), opts)
+
+ // Create a client if it doesn't exist
+ if b.httpClient == nil {
+ b.httpClient = httpw.NewClient()
+ }
+ return b.httpClient.Do(method, u.String(), opts)
}
func apiAuthorizedRetry(
diff --git a/internal/server/base.go b/internal/server/base.go
index dd15aff..55fff09 100644
--- a/internal/server/base.go
+++ b/internal/server/base.go
@@ -2,6 +2,7 @@ package server
import (
"time"
+ "github.com/eduvpn/eduvpn-common/internal/http"
)
// Base is the base type for servers.
@@ -14,6 +15,7 @@ type Base struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"expire_time"`
Type string `json:"server_type"`
+ httpClient *http.Client
}
func (b *Base) InitializeEndpoints() error {