diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-10-25 15:27:23 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-10-28 17:02:14 +0100 |
| commit | 0076386bca8b1e49673f50323cd147ac080cfc2f (patch) | |
| tree | 15aa6ee6cf752db189e0b2b6f75376c9644d384d /internal/api | |
| parent | 8cd50acd5c961bd9c52f1fcbaf18ddc1015accd0 (diff) | |
API + HTTP + Exports: Cleaner TLS1.3 enforcement using a custom DefaultTransport
Also fix where TLS 1.3 was not properly enforced for the endpoint cache
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/api.go | 23 | ||||
| -rw-r--r-- | internal/api/api_test.go | 2 |
2 files changed, 15 insertions, 10 deletions
diff --git a/internal/api/api.go b/internal/api/api.go index 931f273..9b794db 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -47,8 +47,17 @@ type ServerData struct { ProcessAuth func(context.Context, string) (string, error) // DisableAuthorize indicates whether or not new authorization requests should be disabled DisableAuthorize bool - // Transport is the HTTP transport, only used for testing currently - Transport http.RoundTripper + // transport is the HTTP transport, only used for testing currently + transport http.RoundTripper +} + +// Transport returns the transport to be used for the server +// By default it uses the transport from internal/http DefaultTransport +func (s *ServerData) Transport() http.RoundTripper { + if s.transport == nil { + return httpw.DefaultTransport + } + return s.transport } // API is the top-level struct that each method is defined on @@ -65,15 +74,11 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t cr := customRedirect(clientID) // Construct OAuth - transp := sd.Transport - // in the tests this can be non-nil - if transp == nil { - transp = httpw.TLS13Transport() - } + transp := sd.Transport() o := eduoauth.OAuth{ ClientID: clientID, EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { - ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport) + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) if err != nil { return nil, err } @@ -160,7 +165,7 @@ func (a *API) authorize(ctx context.Context) (err error) { } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport) + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) if err != nil { return nil, nil, err } diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 2d17e96..f15bac4 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -199,7 +199,7 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha return in, nil }, DisableAuthorize: false, - Transport: servc.Client.Transport, + transport: servc.Client.Transport, } tc := &TestCallback{t: t} |
