summaryrefslogtreecommitdiff
path: root/internal/api/api.go
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2024-10-25 15:27:23 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2024-10-28 17:02:14 +0100
commit0076386bca8b1e49673f50323cd147ac080cfc2f (patch)
tree15aa6ee6cf752db189e0b2b6f75376c9644d384d /internal/api/api.go
parent8cd50acd5c961bd9c52f1fcbaf18ddc1015accd0 (diff)
API + HTTP + Exports: Cleaner TLS1.3 enforcement using a custom DefaultTransport
Also fix where TLS 1.3 was not properly enforced for the endpoint cache
Diffstat (limited to 'internal/api/api.go')
-rw-r--r--internal/api/api.go23
1 files changed, 14 insertions, 9 deletions
diff --git a/internal/api/api.go b/internal/api/api.go
index 931f273..9b794db 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -47,8 +47,17 @@ type ServerData struct {
ProcessAuth func(context.Context, string) (string, error)
// DisableAuthorize indicates whether or not new authorization requests should be disabled
DisableAuthorize bool
- // Transport is the HTTP transport, only used for testing currently
- Transport http.RoundTripper
+ // transport is the HTTP transport, only used for testing currently
+ transport http.RoundTripper
+}
+
+// Transport returns the transport to be used for the server
+// By default it uses the transport from internal/http DefaultTransport
+func (s *ServerData) Transport() http.RoundTripper {
+ if s.transport == nil {
+ return httpw.DefaultTransport
+ }
+ return s.transport
}
// API is the top-level struct that each method is defined on
@@ -65,15 +74,11 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t
cr := customRedirect(clientID)
// Construct OAuth
- transp := sd.Transport
- // in the tests this can be non-nil
- if transp == nil {
- transp = httpw.TLS13Transport()
- }
+ transp := sd.Transport()
o := eduoauth.OAuth{
ClientID: clientID,
EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) {
- ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport)
+ ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp)
if err != nil {
return nil, err
}
@@ -160,7 +165,7 @@ func (a *API) authorize(ctx context.Context) (err error) {
}
func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
- ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport)
+ ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport())
if err != nil {
return nil, nil, err
}