diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-05-24 12:35:42 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-05-24 13:25:11 +0200 |
| commit | 575a0a53b149ac2da21e368ef809dd2180a878f5 (patch) | |
| tree | 539b883e597c4b4ed010208cf7583503ffddfe90 /internal | |
| parent | 8b6a7cec50711e5568abb416c87ef3995341b377 (diff) | |
API Test: Mock Transport by passing it around
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/api/api.go | 10 | ||||
| -rw-r--r-- | internal/api/api_test.go | 12 | ||||
| -rw-r--r-- | internal/api/cache.go | 5 |
3 files changed, 16 insertions, 11 deletions
diff --git a/internal/api/api.go b/internal/api/api.go index 9c24315..16f86af 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -47,6 +47,8 @@ type ServerData struct { ProcessAuth func(string) string // DisableAuthorize indicates whether or not new authorization requests should be disabled DisableAuthorize bool + // Transport is the HTTP transport, only used for testing currently + Transport http.RoundTripper } // API is the top-level struct that each method is defined on @@ -65,7 +67,7 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t o := eduoauth.OAuth{ ClientID: clientID, EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { - ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK) + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport) if err != nil { return nil, err } @@ -79,6 +81,7 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t TokensUpdated: func(tok eduoauth.Token) { cb.TokensUpdated(sd.ID, sd.Type, tok) }, + Transport: sd.Transport, } if tokens != nil { @@ -147,7 +150,7 @@ func (a *API) authorize(ctx context.Context) (err error) { } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK) + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport) if err != nil { return nil, nil, err } @@ -329,12 +332,13 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto }, nil } -func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) { +func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) { uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal") if err != nil { return nil, err } httpC := httpw.NewClient(nil) + httpC.Client.Transport = tp _, body, err := httpC.Get(ctx, uStr) if err != nil { return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index c9f75ca..c126af4 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -186,6 +186,10 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha }...) // start server serv := test.NewServerWithHandles(hps, listen) + servc, err := serv.Client() + if err != nil { + t.Fatalf("failed to setup HTTP test server client: %v", servc) + } sd := ServerData{ ID: "randomidentifier", @@ -196,13 +200,9 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha return in }, DisableAuthorize: false, + Transport: servc.Client.Transport, } - servc, err := serv.Client() - if err != nil { - t.Fatalf("failed to setup HTTP test server client: %v", servc) - } - // TODO: Mock underlying clients instead - http.DefaultTransport = servc.Client.Transport + tc := &TestCallback{t: t} diff --git a/internal/api/cache.go b/internal/api/cache.go index 4777334..420a4b0 100644 --- a/internal/api/cache.go +++ b/internal/api/cache.go @@ -2,6 +2,7 @@ package api import ( "context" + "net/http" "sync" "time" @@ -16,7 +17,7 @@ type EndpointCache struct { } // Get() returns a cached or fresh endpoint cache copy -func (ec *EndpointCache) Get(ctx context.Context, wk string) (*endpoints.Endpoints, error) { +func (ec *EndpointCache) Get(ctx context.Context, wk string, transport http.RoundTripper) (*endpoints.Endpoints, error) { ec.mu.Lock() defer ec.mu.Unlock() @@ -35,7 +36,7 @@ func (ec *EndpointCache) Get(ctx context.Context, wk string) (*endpoints.Endpoin } // get fresh API endpoints - ep, err := getEndpoints(ctx, wk) + ep, err := getEndpoints(ctx, wk, transport) if err != nil { return nil, err } |
