From 807140ce43584e9612f7b5890b13d751247f8e6e Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Thu, 16 Feb 2023 15:48:52 +0100 Subject: Server: Validate endpoints This commit validates the server endpoints by checking the Host and scheme of each URL to check if they match eachother. This is to prevent further mixup attacks --- internal/server/api.go | 51 ++++++++++++++--- internal/server/api_test.go | 132 ++++++++++++++++++++++++++++++++++++++++++++ internal/server/base.go | 2 +- internal/server/server.go | 10 ++-- 4 files changed, 181 insertions(+), 14 deletions(-) create mode 100644 internal/server/api_test.go (limited to 'internal/server') diff --git a/internal/server/api.go b/internal/server/api.go index d65e923..bede643 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -12,25 +12,58 @@ import ( "github.com/go-errors/errors" ) -func APIGetEndpoints(baseURL string) (*Endpoints, error) { - u, err := url.Parse(baseURL) +func validateEndpoints(endpoints Endpoints) error { + v3 := endpoints.API.V3 + pAPI, err := url.Parse(v3.API) if err != nil { - return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) + return errors.WrapPrefix(err, "failed to parse API endpoint", 0) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API token endpoint", 0) + } + if pAPI.Scheme != pAuth.Scheme { + return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + if pAPI.Host != pAuth.Host { + return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host) } + if pAPI.Host != pToken.Host { + return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host) + } + return nil +} - u.Path = path.Join(u.Path, "/.well-known/vpn-user-portal") - c := httpw.NewClient() - _, body, err := c.Get(u.String()) +func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) { + uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal") + if err != nil { + return nil, err + } + if client == nil { + client = httpw.NewClient() + } + _, body, err := client.Get(uStr) if err != nil { return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - ep := &Endpoints{} - if err = json.Unmarshal(body, ep); err != nil { + ep := Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } + err = validateEndpoints(ep) + if err != nil { + return nil, err + } - return ep, nil + return &ep, nil } func apiAuthorized( diff --git a/internal/server/api_test.go b/internal/server/api_test.go new file mode 100644 index 0000000..00fba3d --- /dev/null +++ b/internal/server/api_test.go @@ -0,0 +1,132 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/go-errors/errors" +) + + +func getErrorMsg(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool { + v3_1 := ep1.API.V3 + v3_2 := ep2.API.V3 + return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token +} + +func Test_APIGetEndpoints(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello!") + }) + hs := &test.HandlerSet{} + hs.SetHandler(handler) + s := test.NewServer(hs) + defer s.Close() + + c, err := s.Client() + if err != nil { + t.Fatalf("failed to get client for test server endpoints: %v", err) + } + + testCases := []struct { + epl EndpointList + err error + }{ + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: nil, + }, + { + epl: EndpointList{ + API: "http://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"), + }, + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "ftp://example.com/3", + }, + err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"), + }, + { + epl: EndpointList{ + API: "https://malicious.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"), + }, + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://malicious.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"), + }, + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://malicious.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"), + }, + } + + for _, tc := range testCases { + ep := &Endpoints{ + API: EndpointsVersions{ + V3: tc.epl, + }, + } + // Update the handler + hs.SetHandler(http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + jsonStr, err := json.Marshal(ep) + if err != nil { + t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err) + } + + fmt.Fprintln(w, string(jsonStr)) + + })) + gotEP, err := APIGetEndpoints(s.URL, c) + if getErrorMsg(err) != getErrorMsg(tc.err) { + t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err) + } + // The error was not nil, continue because endpoints should not be compared + if tc.err != nil { + continue + } + if ep == nil { + t.Fatalf("No test case endpoints") + } + if gotEP == nil { + t.Fatalf("Got no endpoints for nil error") + } + // if no error then the endpoints should be equal + if !compareEndpoints(*ep, *gotEP) { + t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep) + } + } +} diff --git a/internal/server/base.go b/internal/server/base.go index 55fff09..87bb488 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -19,7 +19,7 @@ type Base struct { } func (b *Base) InitializeEndpoints() error { - ep, err := APIGetEndpoints(b.URL) + ep, err := APIGetEndpoints(b.URL, b.httpClient) if err != nil { return err } diff --git a/internal/server/server.go b/internal/server/server.go index 00324a2..c68916e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -33,12 +33,14 @@ type EndpointList struct { Token string `json:"token_endpoint"` } +type EndpointsVersions struct { + V2 EndpointList `json:"http://eduvpn.org/api#2"` + V3 EndpointList `json:"http://eduvpn.org/api#3"` +} + // Endpoints defines the json format for /.well-known/vpn-user-portal". type Endpoints struct { - API struct { - V2 EndpointList `json:"http://eduvpn.org/api#2"` - V3 EndpointList `json:"http://eduvpn.org/api#3"` - } `json:"api"` + API EndpointsVersions `json:"api"` V string `json:"v"` } -- cgit v1.2.3