summaryrefslogtreecommitdiff
path: root/internal/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api')
-rw-r--r--internal/api/api.go10
-rw-r--r--internal/api/api_test.go12
-rw-r--r--internal/api/cache.go5
3 files changed, 16 insertions, 11 deletions
diff --git a/internal/api/api.go b/internal/api/api.go
index 9c24315..16f86af 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -47,6 +47,8 @@ type ServerData struct {
ProcessAuth func(string) string
// DisableAuthorize indicates whether or not new authorization requests should be disabled
DisableAuthorize bool
+ // Transport is the HTTP transport, only used for testing currently
+ Transport http.RoundTripper
}
// API is the top-level struct that each method is defined on
@@ -65,7 +67,7 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t
o := eduoauth.OAuth{
ClientID: clientID,
EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) {
- ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK)
+ ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport)
if err != nil {
return nil, err
}
@@ -79,6 +81,7 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t
TokensUpdated: func(tok eduoauth.Token) {
cb.TokensUpdated(sd.ID, sd.Type, tok)
},
+ Transport: sd.Transport,
}
if tokens != nil {
@@ -147,7 +150,7 @@ func (a *API) authorize(ctx context.Context) (err error) {
}
func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
- ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK)
+ ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport)
if err != nil {
return nil, nil, err
}
@@ -329,12 +332,13 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
}, nil
}
-func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) {
+func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) {
uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal")
if err != nil {
return nil, err
}
httpC := httpw.NewClient(nil)
+ httpC.Client.Transport = tp
_, body, err := httpC.Get(ctx, uStr)
if err != nil {
return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
diff --git a/internal/api/api_test.go b/internal/api/api_test.go
index c9f75ca..c126af4 100644
--- a/internal/api/api_test.go
+++ b/internal/api/api_test.go
@@ -186,6 +186,10 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha
}...)
// start server
serv := test.NewServerWithHandles(hps, listen)
+ servc, err := serv.Client()
+ if err != nil {
+ t.Fatalf("failed to setup HTTP test server client: %v", servc)
+ }
sd := ServerData{
ID: "randomidentifier",
@@ -196,13 +200,9 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha
return in
},
DisableAuthorize: false,
+ Transport: servc.Client.Transport,
}
- servc, err := serv.Client()
- if err != nil {
- t.Fatalf("failed to setup HTTP test server client: %v", servc)
- }
- // TODO: Mock underlying clients instead
- http.DefaultTransport = servc.Client.Transport
+
tc := &TestCallback{t: t}
diff --git a/internal/api/cache.go b/internal/api/cache.go
index 4777334..420a4b0 100644
--- a/internal/api/cache.go
+++ b/internal/api/cache.go
@@ -2,6 +2,7 @@ package api
import (
"context"
+ "net/http"
"sync"
"time"
@@ -16,7 +17,7 @@ type EndpointCache struct {
}
// Get() returns a cached or fresh endpoint cache copy
-func (ec *EndpointCache) Get(ctx context.Context, wk string) (*endpoints.Endpoints, error) {
+func (ec *EndpointCache) Get(ctx context.Context, wk string, transport http.RoundTripper) (*endpoints.Endpoints, error) {
ec.mu.Lock()
defer ec.mu.Unlock()
@@ -35,7 +36,7 @@ func (ec *EndpointCache) Get(ctx context.Context, wk string) (*endpoints.Endpoin
}
// get fresh API endpoints
- ep, err := getEndpoints(ctx, wk)
+ ep, err := getEndpoints(ctx, wk, transport)
if err != nil {
return nil, err
}