From 0076386bca8b1e49673f50323cd147ac080cfc2f Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Fri, 25 Oct 2024 15:27:23 +0200 Subject: API + HTTP + Exports: Cleaner TLS1.3 enforcement using a custom DefaultTransport Also fix where TLS 1.3 was not properly enforced for the endpoint cache --- exports/exports_test_wrapper.go | 11 +++++------ internal/api/api.go | 23 ++++++++++++++--------- internal/api/api_test.go | 2 +- internal/http/http.go | 10 +++++++--- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/exports/exports_test_wrapper.go b/exports/exports_test_wrapper.go index a52c0fe..a7305dd 100644 --- a/exports/exports_test_wrapper.go +++ b/exports/exports_test_wrapper.go @@ -24,6 +24,8 @@ import ( "github.com/eduvpn/eduvpn-common/internal/test" "github.com/eduvpn/eduvpn-common/types/error" "github.com/eduvpn/eduvpn-common/util" + + httpw "github.com/eduvpn/eduvpn-common/internal/http" ) func getString(in *C.char) string { @@ -264,8 +266,7 @@ func testServerList(t *testing.T) { t.Fatalf("failed to obtain server client: %v", err) } - // TODO: can we do this better - http.DefaultTransport = sclient.Client.Transport + httpw.DefaultTransport = sclient.Client.Transport.(*http.Transport) gerr := getError(t, AddServer(ck, 3, listS, nil)) if gerr != "" { @@ -417,8 +418,7 @@ func testGetConfig(t *testing.T) { t.Fatalf("failed to obtain server client: %v", err) } - // TODO: can we do this better - http.DefaultTransport = sclient.Client.Transport + httpw.DefaultTransport = sclient.Client.Transport.(*http.Transport) _, cfgErr := GetConfig(ck, 3, listS, 0, 0) cfgErrS := getError(t, cfgErr) @@ -501,8 +501,7 @@ func testLetsConnectDiscovery(t *testing.T) { t.Fatalf("failed to obtain server client: %v", err) } - // TODO: can we do this better - http.DefaultTransport = sclient.Client.Transport + httpw.DefaultTransport = sclient.Client.Transport.(*http.Transport) // try to add an institute access server exptErr := fmt.Sprintf("An internal error occurred. The cause of the error is: Adding a non-custom server when the client does not use discovery is not supported, identifier: %s, type: 1.", list) diff --git a/internal/api/api.go b/internal/api/api.go index 931f273..9b794db 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -47,8 +47,17 @@ type ServerData struct { ProcessAuth func(context.Context, string) (string, error) // DisableAuthorize indicates whether or not new authorization requests should be disabled DisableAuthorize bool - // Transport is the HTTP transport, only used for testing currently - Transport http.RoundTripper + // transport is the HTTP transport, only used for testing currently + transport http.RoundTripper +} + +// Transport returns the transport to be used for the server +// By default it uses the transport from internal/http DefaultTransport +func (s *ServerData) Transport() http.RoundTripper { + if s.transport == nil { + return httpw.DefaultTransport + } + return s.transport } // API is the top-level struct that each method is defined on @@ -65,15 +74,11 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t cr := customRedirect(clientID) // Construct OAuth - transp := sd.Transport - // in the tests this can be non-nil - if transp == nil { - transp = httpw.TLS13Transport() - } + transp := sd.Transport() o := eduoauth.OAuth{ ClientID: clientID, EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { - ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, sd.Transport) + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) if err != nil { return nil, err } @@ -160,7 +165,7 @@ func (a *API) authorize(ctx context.Context) (err error) { } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport) + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) if err != nil { return nil, nil, err } diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 2d17e96..f15bac4 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -199,7 +199,7 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha return in, nil }, DisableAuthorize: false, - Transport: servc.Client.Transport, + transport: servc.Client.Transport, } tc := &TestCallback{t: t} diff --git a/internal/http/http.go b/internal/http/http.go index a7240e1..aeb113e 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -147,19 +147,23 @@ type Client struct { Timeout time.Duration } -// TLS13Transport returns a http.Transport with the minimum TLS version set to 1.3 -func TLS13Transport() *http.Transport { +// tls13Transport returns a http.Transport with the minimum TLS version set to 1.3 +func tls13Transport() *http.Transport { tr := http.DefaultTransport.(*http.Transport).Clone() tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS13} return tr } +// DefaultTransport is the default HTTP transport to use +// by default it is a transport that only allows TLS 1.3 +var DefaultTransport = tls13Transport() + // NewClient returns a HTTP client with some default settings func NewClient(client *http.Client) *Client { c := client if c == nil { c = &http.Client{ - Transport: TLS13Transport(), + Transport: DefaultTransport, } } // if a client is non-nil it uses its own transport -- cgit v1.2.3