summaryrefslogtreecommitdiff
path: root/internal/api/api.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/api/api.go')
-rw-r--r--internal/api/api.go10
1 files changed, 7 insertions, 3 deletions
diff --git a/internal/api/api.go b/internal/api/api.go
index 9c24315..16f86af 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -47,6 +47,8 @@ type ServerData struct {
ProcessAuth func(string) string
// DisableAuthorize indicates whether or not new authorization requests should be disabled
DisableAuthorize bool
+ // Transport is the HTTP transport, only used for testing currently
+ Transport http.RoundTripper
}
// API is the top-level struct that each method is defined on
@@ -65,7 +67,7 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t
o := eduoauth.OAuth{
ClientID: clientID,
EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) {
- ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK)
+ ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport)
if err != nil {
return nil, err
}
@@ -79,6 +81,7 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t
TokensUpdated: func(tok eduoauth.Token) {
cb.TokensUpdated(sd.ID, sd.Type, tok)
},
+ Transport: sd.Transport,
}
if tokens != nil {
@@ -147,7 +150,7 @@ func (a *API) authorize(ctx context.Context) (err error) {
}
func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
- ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK)
+ ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport)
if err != nil {
return nil, nil, err
}
@@ -329,12 +332,13 @@ func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []proto
}, nil
}
-func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) {
+func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) {
uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal")
if err != nil {
return nil, err
}
httpC := httpw.NewClient(nil)
+ httpC.Client.Transport = tp
_, body, err := httpC.Get(ctx, uStr)
if err != nil {
return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)