diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/api.go | 23 | ||||
| -rw-r--r-- | internal/api/api_test.go | 2 | ||||
| -rw-r--r-- | internal/http/http.go | 10 |
3 files changed, 22 insertions, 13 deletions
diff --git a/internal/api/api.go b/internal/api/api.go index 931f273..9b794db 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -47,8 +47,17 @@ type ServerData struct { ProcessAuth func(context.Context, string) (string, error) // DisableAuthorize indicates whether or not new authorization requests should be disabled DisableAuthorize bool - // Transport is the HTTP transport, only used for testing currently - Transport http.RoundTripper + // transport is the HTTP transport, only used for testing currently + transport http.RoundTripper +} + +// Transport returns the transport to be used for the server +// By default it uses the transport from internal/http DefaultTransport +func (s *ServerData) Transport() http.RoundTripper { + if s.transport == nil { + return httpw.DefaultTransport + } + return s.transport } // API is the top-level struct that each method is defined on @@ -65,15 +74,11 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t cr := customRedirect(clientID) // Construct OAuth - transp := sd.Transport - // in the tests this can be non-nil - if transp == nil { - transp = httpw.TLS13Transport() - } + transp := sd.Transport() o := eduoauth.OAuth{ ClientID: clientID, EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { - ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport) + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) if err != nil { return nil, err } @@ -160,7 +165,7 @@ func (a *API) authorize(ctx context.Context) (err error) { } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport) + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) if err != nil { return nil, nil, err } diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 2d17e96..f15bac4 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -199,7 +199,7 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha return in, nil }, DisableAuthorize: false, - Transport: servc.Client.Transport, + transport: servc.Client.Transport, } tc := &TestCallback{t: t} diff --git a/internal/http/http.go b/internal/http/http.go index a7240e1..aeb113e 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -147,19 +147,23 @@ type Client struct { Timeout time.Duration } -// TLS13Transport returns a http.Transport with the minimum TLS version set to 1.3 -func TLS13Transport() *http.Transport { +// tls13Transport returns a http.Transport with the minimum TLS version set to 1.3 +func tls13Transport() *http.Transport { tr := http.DefaultTransport.(*http.Transport).Clone() tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS13} return tr } +// DefaultTransport is the default HTTP transport to use +// by default it is a transport that only allows TLS 1.3 +var DefaultTransport = tls13Transport() + // NewClient returns a HTTP client with some default settings func NewClient(client *http.Client) *Client { c := client if c == nil { c = &http.Client{ - Transport: TLS13Transport(), + Transport: DefaultTransport, } } // if a client is non-nil it uses its own transport |
