diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-02-16 15:48:52 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-02-16 15:52:05 +0100 |
| commit | 807140ce43584e9612f7b5890b13d751247f8e6e (patch) | |
| tree | 08e05fd79078f5093bc7aea68557b212bb5c1bfa /internal/server/api_test.go | |
| parent | f718788442682f87e2fd1b6067f6062bade52d52 (diff) | |
Server: Validate endpoints
This commit validates the server endpoints by checking the Host and
scheme of each URL to check if they match eachother. This is to
prevent further mixup attacks
Diffstat (limited to 'internal/server/api_test.go')
| -rw-r--r-- | internal/server/api_test.go | 132 |
1 files changed, 132 insertions, 0 deletions
diff --git a/internal/server/api_test.go b/internal/server/api_test.go new file mode 100644 index 0000000..00fba3d --- /dev/null +++ b/internal/server/api_test.go @@ -0,0 +1,132 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/go-errors/errors" +) + + +func getErrorMsg(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool { + v3_1 := ep1.API.V3 + v3_2 := ep2.API.V3 + return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token +} + +func Test_APIGetEndpoints(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello!") + }) + hs := &test.HandlerSet{} + hs.SetHandler(handler) + s := test.NewServer(hs) + defer s.Close() + + c, err := s.Client() + if err != nil { + t.Fatalf("failed to get client for test server endpoints: %v", err) + } + + testCases := []struct { + epl EndpointList + err error + }{ + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: nil, + }, + { + epl: EndpointList{ + API: "http://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"), + }, + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "ftp://example.com/3", + }, + err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"), + }, + { + epl: EndpointList{ + API: "https://malicious.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"), + }, + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://malicious.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"), + }, + { + epl: EndpointList{ + API: "https://example.com/1", + Authorization: "https://malicious.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"), + }, + } + + for _, tc := range testCases { + ep := &Endpoints{ + API: EndpointsVersions{ + V3: tc.epl, + }, + } + // Update the handler + hs.SetHandler(http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + jsonStr, err := json.Marshal(ep) + if err != nil { + t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err) + } + + fmt.Fprintln(w, string(jsonStr)) + + })) + gotEP, err := APIGetEndpoints(s.URL, c) + if getErrorMsg(err) != getErrorMsg(tc.err) { + t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err) + } + // The error was not nil, continue because endpoints should not be compared + if tc.err != nil { + continue + } + if ep == nil { + t.Fatalf("No test case endpoints") + } + if gotEP == nil { + t.Fatalf("Got no endpoints for nil error") + } + // if no error then the endpoints should be equal + if !compareEndpoints(*ep, *gotEP) { + t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep) + } + } +} |
