diff options
Diffstat (limited to 'internal/httpwrap')
| -rw-r--r-- | internal/httpwrap/httpwrap.go | 288 | ||||
| -rw-r--r-- | internal/httpwrap/httpwrap_test.go | 61 |
2 files changed, 349 insertions, 0 deletions
diff --git a/internal/httpwrap/httpwrap.go b/internal/httpwrap/httpwrap.go new file mode 100644 index 0000000..5fd42c8 --- /dev/null +++ b/internal/httpwrap/httpwrap.go @@ -0,0 +1,288 @@ +// Package httpwrap defines higher level helpers for the net/http package +package httpwrap + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "path" + "strings" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/commonver" +) + +// UserAgent is the user agent that is used for requests +var UserAgent string + +// URLParameters is a type used for the parameters in the URL. +type URLParameters map[string]string + +// OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call. +type OptionalParams struct { + Headers http.Header + URLParameters URLParameters + Body url.Values + Timeout time.Duration +} + +func cleanPath(u *url.URL, trailing bool) string { + if u.Path != "" { + // Clean the path + // https://pkg.go.dev/path#Clean + u.Path = path.Clean(u.Path) + } + + str := u.String() + + // Make sure the URL ends with a / + if trailing && str[len(str)-1:] != "/" { + str += "/" + } + return str +} + +// EnsureValidURL ensures that the input URL is valid to be used internally +// It does the following +// - Sets the scheme to https if none is given +// - It 'cleans' up the path using path.Clean +// - It makes sure that the URL ends with a / +// It returns an error if the URL cannot be parsed. +func EnsureValidURL(s string, trailing bool) (string, error) { + u, err := url.Parse(s) + if err != nil { + return "", fmt.Errorf("failed parsing url with error: %w", err) + } + + // Make sure the scheme is always https + if u.Scheme != "https" { + u.Scheme = "https" + } + return cleanPath(u, trailing), nil +} + +// JoinURLPath joins url's path, in go 1.19 we can use url.JoinPath +func JoinURLPath(u string, p string) (string, error) { + pu, err := url.Parse(u) + if err != nil { + return "", fmt.Errorf("failed to parse url for joining paths with error: %w", err) + } + pp, err := url.Parse(p) + if err != nil { + return "", fmt.Errorf("failed to parse path for joining paths with error: %w", err) + } + fp := pu.ResolveReference(pp) + + // We also clean the path for consistency + return cleanPath(fp, false), nil +} + +// ConstructURL creates a URL with the included parameters. +func ConstructURL(u *url.URL, params URLParameters) (string, error) { + q := u.Query() + + for p, value := range params { + q.Set(p, value) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +// optionalURL ensures that the URL contains the optional parameters +// it returns the url (with parameters if success) and an error indicating success. +func optionalURL(urlStr string, opts *OptionalParams) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", fmt.Errorf("failed to construct parse url '%s' with error: %w", urlStr, err) + } + // Make sure the scheme is always set to HTTPS + if u.Scheme != "https" { + u.Scheme = "https" + } + + if opts == nil { + return u.String(), nil + } + + return ConstructURL(u, opts.URLParameters) +} + +// optionalHeaders ensures that the HTTP request uses the optional headers if defined. +func optionalHeaders(req *http.Request, opts *OptionalParams) { + // Add headers + if opts != nil && req != nil && opts.Headers != nil { + for k, v := range opts.Headers { + for _, cv := range v { + req.Header.Add(k, cv) + } + } + } +} + +// optionalBodyReader returns a HTTP body reader if there is a body, otherwise nil. +func optionalBodyReader(opts *OptionalParams) io.Reader { + if opts != nil && opts.Body != nil { + return strings.NewReader(opts.Body.Encode()) + } + return nil +} + +// Client is a wrapper around http.Client with some convenience features +// - A default timeout of 5 seconds +// - A read limiter to prevent servers from sending large amounts of data +// - Checking on http code with custom errors +type Client struct { + // Client is the HTTP Client that sends the request + Client *http.Client + // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses + // This is used to prevent servers from sending huge amounts of data + // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems + ReadLimit int64 + + // Timeout denotes the default timeout for each request + Timeout time.Duration +} + +// tls13Transport returns a http.Transport with the minimum TLS version set to 1.3 +func tls13Transport() *http.Transport { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS13} + return tr +} + +// DefaultTransport is the default HTTP transport to use +// by default it is a transport that only allows TLS 1.3 +var DefaultTransport = tls13Transport() + +// NewClient returns a HTTP client with some default settings +func NewClient(client *http.Client) *Client { + c := client + if c == nil { + c = &http.Client{ + Transport: DefaultTransport, + } + } + // if a client is non-nil it uses its own transport + // for the OAuth client we also make sure TLS 1.3 is set + // TODO: Should we double verify that MinVersion is 1.3 or is that overkill? + + // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses + // This is used to prevent servers from sending huge amounts of data + // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems + // The timeout is 10 seconds by default. We pass it here and not in the http client because we want to do it per request + return &Client{Client: c, ReadLimit: 16 << 20, Timeout: 10 * time.Second} +} + +// Get creates a Get request and returns the headers, body and an error. +func (c *Client) Get(ctx context.Context, url string) (http.Header, []byte, error) { + return c.Do(ctx, http.MethodGet, url, nil) +} + +// Do sends a HTTP request using a method (e.g. GET, POST), an url and optional parameters +// It returns the HTTP headers, the body and an error if there is one. +func (c *Client) Do(ctx context.Context, method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) { + // Make sure the url contains all the parameters + // This can return an error, + // it already has the right error, so we don't wrap it further + urlStr, err := optionalURL(urlStr, opts) + if err != nil { + // No further type wrapping is needed here + return nil, nil, err + } + + // The timeout is configurable for each request + timeout := c.Timeout + if opts != nil && opts.Timeout.Seconds() > 0 { + timeout = opts.Timeout + } + + ctx, cncl := context.WithTimeout(ctx, timeout) + defer cncl() + + slog.Debug("sending request", "method", method, "url", urlStr) + + // Create request object with the body reader generated from the optional arguments + req, err := http.NewRequestWithContext(ctx, method, urlStr, optionalBodyReader(opts)) + if err != nil { + return nil, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s' and error: %w", method, urlStr, err) + } + if UserAgent != "" { + req.Header.Add("User-Agent", UserAgent) + } + + // Make sure the headers contain all the parameters + optionalHeaders(req, opts) + + // Do request + res, err := c.Client.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, nil, &TimeoutError{URL: urlStr, Method: method} + } + return nil, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s' and error: %w", method, urlStr, err) + } + + // Request successful, make sure body is closed at the end + defer func() { + _ = res.Body.Close() + }() + + // Return a string + // A max bytes reader is normally used for request bodies with a writer + // However, this is still nice to use because unlike a limitreader, it returns an error if the body is too large + // We use this function without a writer so we pass nil + // We impose a limit because servers could be malicious and send huge amounts of data + r := http.MaxBytesReader(nil, res.Body, c.ReadLimit) + body, err := io.ReadAll(r) + if err != nil { + return res.Header, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s', max bytes size: '%v' and error: %w", method, urlStr, c.ReadLimit, err) + } + if res.StatusCode < 200 || res.StatusCode > 299 { + return res.Header, body, fmt.Errorf("failed HTTP request with method: '%s' due to a status error: %w", method, &StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}) + } + + // Return the body in bytes and signal the status error if there was one + return res.Header, body, nil +} + +// TimeoutError indicates that we have gotten a timeout +type TimeoutError struct { + URL string + Method string +} + +// Error returns the TimeoutError as an error string. +func (e *TimeoutError) Error() string { + return fmt.Sprintf( + "timeout in obtaining HTTP resource: '%s' with method: '%s'", + e.URL, + e.Method, + ) +} + +// StatusError indicates that we have received a HTTP status error. +type StatusError struct { + URL string + Body string + Status int +} + +// Error returns the StatusError as an error string. +func (e *StatusError) Error() string { + return fmt.Sprintf( + "failed obtaining HTTP resource: '%s' as it gave an unsuccessful status code: '%d'. Body: '%s'", + e.URL, + e.Status, + e.Body, + ) +} + +// RegisterAgent registers the user agent for client and version +func RegisterAgent(client string, verApp string) { + UserAgent = fmt.Sprintf("%s/%s eduvpn-common/%s", client, verApp, commonver.Version) +} diff --git a/internal/httpwrap/httpwrap_test.go b/internal/httpwrap/httpwrap_test.go new file mode 100644 index 0000000..422ee3f --- /dev/null +++ b/internal/httpwrap/httpwrap_test.go @@ -0,0 +1,61 @@ +package httpwrap + +import ( + "testing" +) + +func TestEnsureValidURL(t *testing.T) { + _, validErr := EnsureValidURL("%notvalid%", true) + + if validErr == nil { + t.Fatal("Got nil error, want: non-nil") + } + + testCases := map[string]string{ + // Make sure we set https + "example.com/": "https://example.com/", + // Make sure we do override the scheme to https + "http://example.com/": "https://example.com/", + // This URL is already valid + "https://example.com/": "https://example.com/", + // Make sure to add a trailing slash (/) + "https://example.com": "https://example.com/", + // Cleanup the path 1 + "https://example.com/////": "https://example.com/", + // Cleanup the path 2 + "https://example.com/..": "https://example.com/", + } + + for k, v := range testCases { + valid, validErr := EnsureValidURL(k, true) + if validErr != nil { + t.Fatalf("Got: %v, want: nil", validErr) + } + if valid != v { + t.Fatalf("Got: %v, want: %v", valid, v) + } + } +} + +func Test_JoinURLPath(t *testing.T) { + cases := []struct { + u string + p string + want string + }{ + {u: "https://example.com", p: "test", want: "https://example.com/test"}, + {u: "https://example.com", p: "/test", want: "https://example.com/test"}, + {u: "https://example.com", p: "../test", want: "https://example.com/test"}, + {u: "https://example.com", p: "../test/", want: "https://example.com/test"}, + {u: "https://example.com", p: "test/", want: "https://example.com/test"}, + } + for _, c := range cases { + got, err := JoinURLPath(c.u, c.p) + if err != nil { + t.Fatalf("Failed to parse join url case: %v, err: %v", c, err) + } + if got != c.want { + t.Fatalf("Failed test case for joining URL, want: %v, got: %v", c.want, got) + } + } +} |
