summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-02-16 15:48:52 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2023-02-16 15:52:05 +0100
commit807140ce43584e9612f7b5890b13d751247f8e6e (patch)
tree08e05fd79078f5093bc7aea68557b212bb5c1bfa
parentf718788442682f87e2fd1b6067f6062bade52d52 (diff)
Server: Validate endpoints
This commit validates the server endpoints by checking the Host and scheme of each URL to check if they match eachother. This is to prevent further mixup attacks
-rw-r--r--internal/server/api.go51
-rw-r--r--internal/server/api_test.go132
-rw-r--r--internal/server/base.go2
-rw-r--r--internal/server/server.go10
-rw-r--r--internal/test/handler.go25
5 files changed, 206 insertions, 14 deletions
diff --git a/internal/server/api.go b/internal/server/api.go
index d65e923..bede643 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -12,25 +12,58 @@ import (
"github.com/go-errors/errors"
)
-func APIGetEndpoints(baseURL string) (*Endpoints, error) {
- u, err := url.Parse(baseURL)
+func validateEndpoints(endpoints Endpoints) error {
+ v3 := endpoints.API.V3
+ pAPI, err := url.Parse(v3.API)
if err != nil {
- return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ return errors.WrapPrefix(err, "failed to parse API endpoint", 0)
+ }
+ pAuth, err := url.Parse(v3.Authorization)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0)
+ }
+ pToken, err := url.Parse(v3.Token)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API token endpoint", 0)
+ }
+ if pAPI.Scheme != pAuth.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
+ }
+ if pAPI.Scheme != pToken.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
+ }
+ if pAPI.Host != pAuth.Host {
+ return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host)
}
+ if pAPI.Host != pToken.Host {
+ return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host)
+ }
+ return nil
+}
- u.Path = path.Join(u.Path, "/.well-known/vpn-user-portal")
- c := httpw.NewClient()
- _, body, err := c.Get(u.String())
+func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) {
+ uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return nil, err
+ }
+ if client == nil {
+ client = httpw.NewClient()
+ }
+ _, body, err := client.Get(uStr)
if err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- ep := &Endpoints{}
- if err = json.Unmarshal(body, ep); err != nil {
+ ep := Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
+ err = validateEndpoints(ep)
+ if err != nil {
+ return nil, err
+ }
- return ep, nil
+ return &ep, nil
}
func apiAuthorized(
diff --git a/internal/server/api_test.go b/internal/server/api_test.go
new file mode 100644
index 0000000..00fba3d
--- /dev/null
+++ b/internal/server/api_test.go
@@ -0,0 +1,132 @@
+package server
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/eduvpn/eduvpn-common/internal/test"
+ "github.com/go-errors/errors"
+)
+
+
+func getErrorMsg(err error) string {
+ if err == nil {
+ return ""
+ }
+ return err.Error()
+}
+
+func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool {
+ v3_1 := ep1.API.V3
+ v3_2 := ep2.API.V3
+ return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token
+}
+
+func Test_APIGetEndpoints(t *testing.T) {
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello!")
+ })
+ hs := &test.HandlerSet{}
+ hs.SetHandler(handler)
+ s := test.NewServer(hs)
+ defer s.Close()
+
+ c, err := s.Client()
+ if err != nil {
+ t.Fatalf("failed to get client for test server endpoints: %v", err)
+ }
+
+ testCases := []struct {
+ epl EndpointList
+ err error
+ }{
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: nil,
+ },
+ {
+ epl: EndpointList{
+ API: "http://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "ftp://example.com/3",
+ },
+ err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://malicious.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://malicious.com/3",
+ },
+ err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://malicious.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"),
+ },
+ }
+
+ for _, tc := range testCases {
+ ep := &Endpoints{
+ API: EndpointsVersions{
+ V3: tc.epl,
+ },
+ }
+ // Update the handler
+ hs.SetHandler(http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ jsonStr, err := json.Marshal(ep)
+ if err != nil {
+ t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err)
+ }
+
+ fmt.Fprintln(w, string(jsonStr))
+
+ }))
+ gotEP, err := APIGetEndpoints(s.URL, c)
+ if getErrorMsg(err) != getErrorMsg(tc.err) {
+ t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err)
+ }
+ // The error was not nil, continue because endpoints should not be compared
+ if tc.err != nil {
+ continue
+ }
+ if ep == nil {
+ t.Fatalf("No test case endpoints")
+ }
+ if gotEP == nil {
+ t.Fatalf("Got no endpoints for nil error")
+ }
+ // if no error then the endpoints should be equal
+ if !compareEndpoints(*ep, *gotEP) {
+ t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep)
+ }
+ }
+}
diff --git a/internal/server/base.go b/internal/server/base.go
index 55fff09..87bb488 100644
--- a/internal/server/base.go
+++ b/internal/server/base.go
@@ -19,7 +19,7 @@ type Base struct {
}
func (b *Base) InitializeEndpoints() error {
- ep, err := APIGetEndpoints(b.URL)
+ ep, err := APIGetEndpoints(b.URL, b.httpClient)
if err != nil {
return err
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 00324a2..c68916e 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -33,12 +33,14 @@ type EndpointList struct {
Token string `json:"token_endpoint"`
}
+type EndpointsVersions struct {
+ V2 EndpointList `json:"http://eduvpn.org/api#2"`
+ V3 EndpointList `json:"http://eduvpn.org/api#3"`
+}
+
// Endpoints defines the json format for /.well-known/vpn-user-portal".
type Endpoints struct {
- API struct {
- V2 EndpointList `json:"http://eduvpn.org/api#2"`
- V3 EndpointList `json:"http://eduvpn.org/api#3"`
- } `json:"api"`
+ API EndpointsVersions `json:"api"`
V string `json:"v"`
}
diff --git a/internal/test/handler.go b/internal/test/handler.go
new file mode 100644
index 0000000..5c02629
--- /dev/null
+++ b/internal/test/handler.go
@@ -0,0 +1,25 @@
+package test
+
+import (
+ "net/http"
+ "sync"
+)
+
+// HandlerSet is a struct with a mutex that allows us to swap handlers while a test server is running
+type HandlerSet struct {
+ mu sync.Mutex
+ handler http.Handler
+}
+
+func (hs *HandlerSet) SetHandler(handler http.Handler) {
+ hs.mu.Lock()
+ hs.handler = handler
+ hs.mu.Unlock()
+}
+
+func (hs *HandlerSet) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ hs.mu.Lock()
+ handler := hs.handler
+ hs.mu.Unlock()
+ handler.ServeHTTP(w, r)
+}