summaryrefslogtreecommitdiff
path: root/internal/api/api.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/api.go')
-rw-r--r--internal/api/api.go23
1 files changed, 14 insertions, 9 deletions
diff --git a/internal/api/api.go b/internal/api/api.go
index 931f273..9b794db 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -47,8 +47,17 @@ type ServerData struct {
ProcessAuth func(context.Context, string) (string, error)
// DisableAuthorize indicates whether or not new authorization requests should be disabled
DisableAuthorize bool
- // Transport is the HTTP transport, only used for testing currently
- Transport http.RoundTripper
+ // transport is the HTTP transport, only used for testing currently
+ transport http.RoundTripper
+}
+
+// Transport returns the transport to be used for the server
+// By default it uses the transport from internal/http DefaultTransport
+func (s *ServerData) Transport() http.RoundTripper {
+ if s.transport == nil {
+ return httpw.DefaultTransport
+ }
+ return s.transport
}
// API is the top-level struct that each method is defined on
@@ -65,15 +74,11 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t
cr := customRedirect(clientID)
// Construct OAuth
- transp := sd.Transport
- // in the tests this can be non-nil
- if transp == nil {
- transp = httpw.TLS13Transport()
- }
+ transp := sd.Transport()
o := eduoauth.OAuth{
ClientID: clientID,
EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) {
- ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport)
+ ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp)
if err != nil {
return nil, err
}
@@ -160,7 +165,7 @@ func (a *API) authorize(ctx context.Context) (err error) {
}
func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
- ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport)
+ ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport())
if err != nil {
return nil, nil, err
}