summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-02-16 15:48:52 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2023-02-16 15:52:05 +0100
commit807140ce43584e9612f7b5890b13d751247f8e6e (patch)
tree08e05fd79078f5093bc7aea68557b212bb5c1bfa /internal
parentf718788442682f87e2fd1b6067f6062bade52d52 (diff)
Server: Validate endpoints
This commit validates the server endpoints by checking the Host and scheme of each URL to check if they match eachother. This is to prevent further mixup attacks
Diffstat (limited to 'internal')
-rw-r--r--internal/server/api.go51
-rw-r--r--internal/server/api_test.go132
-rw-r--r--internal/server/base.go2
-rw-r--r--internal/server/server.go10
-rw-r--r--internal/test/handler.go25
5 files changed, 206 insertions, 14 deletions
diff --git a/internal/server/api.go b/internal/server/api.go
index d65e923..bede643 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -12,25 +12,58 @@ import (
"github.com/go-errors/errors"
)
-func APIGetEndpoints(baseURL string) (*Endpoints, error) {
- u, err := url.Parse(baseURL)
+func validateEndpoints(endpoints Endpoints) error {
+ v3 := endpoints.API.V3
+ pAPI, err := url.Parse(v3.API)
if err != nil {
- return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ return errors.WrapPrefix(err, "failed to parse API endpoint", 0)
+ }
+ pAuth, err := url.Parse(v3.Authorization)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0)
+ }
+ pToken, err := url.Parse(v3.Token)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API token endpoint", 0)
+ }
+ if pAPI.Scheme != pAuth.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
+ }
+ if pAPI.Scheme != pToken.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
+ }
+ if pAPI.Host != pAuth.Host {
+ return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host)
}
+ if pAPI.Host != pToken.Host {
+ return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host)
+ }
+ return nil
+}
- u.Path = path.Join(u.Path, "/.well-known/vpn-user-portal")
- c := httpw.NewClient()
- _, body, err := c.Get(u.String())
+func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) {
+ uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return nil, err
+ }
+ if client == nil {
+ client = httpw.NewClient()
+ }
+ _, body, err := client.Get(uStr)
if err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- ep := &Endpoints{}
- if err = json.Unmarshal(body, ep); err != nil {
+ ep := Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
+ err = validateEndpoints(ep)
+ if err != nil {
+ return nil, err
+ }
- return ep, nil
+ return &ep, nil
}
func apiAuthorized(
diff --git a/internal/server/api_test.go b/internal/server/api_test.go
new file mode 100644
index 0000000..00fba3d
--- /dev/null
+++ b/internal/server/api_test.go
@@ -0,0 +1,132 @@
+package server
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/eduvpn/eduvpn-common/internal/test"
+ "github.com/go-errors/errors"
+)
+
+
+func getErrorMsg(err error) string {
+ if err == nil {
+ return ""
+ }
+ return err.Error()
+}
+
+func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool {
+ v3_1 := ep1.API.V3
+ v3_2 := ep2.API.V3
+ return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token
+}
+
+func Test_APIGetEndpoints(t *testing.T) {
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello!")
+ })
+ hs := &test.HandlerSet{}
+ hs.SetHandler(handler)
+ s := test.NewServer(hs)
+ defer s.Close()
+
+ c, err := s.Client()
+ if err != nil {
+ t.Fatalf("failed to get client for test server endpoints: %v", err)
+ }
+
+ testCases := []struct {
+ epl EndpointList
+ err error
+ }{
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: nil,
+ },
+ {
+ epl: EndpointList{
+ API: "http://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "ftp://example.com/3",
+ },
+ err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://malicious.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://malicious.com/3",
+ },
+ err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"),
+ },
+ {
+ epl: EndpointList{
+ API: "https://example.com/1",
+ Authorization: "https://malicious.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"),
+ },
+ }
+
+ for _, tc := range testCases {
+ ep := &Endpoints{
+ API: EndpointsVersions{
+ V3: tc.epl,
+ },
+ }
+ // Update the handler
+ hs.SetHandler(http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ jsonStr, err := json.Marshal(ep)
+ if err != nil {
+ t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err)
+ }
+
+ fmt.Fprintln(w, string(jsonStr))
+
+ }))
+ gotEP, err := APIGetEndpoints(s.URL, c)
+ if getErrorMsg(err) != getErrorMsg(tc.err) {
+ t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err)
+ }
+ // The error was not nil, continue because endpoints should not be compared
+ if tc.err != nil {
+ continue
+ }
+ if ep == nil {
+ t.Fatalf("No test case endpoints")
+ }
+ if gotEP == nil {
+ t.Fatalf("Got no endpoints for nil error")
+ }
+ // if no error then the endpoints should be equal
+ if !compareEndpoints(*ep, *gotEP) {
+ t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep)
+ }
+ }
+}
diff --git a/internal/server/base.go b/internal/server/base.go
index 55fff09..87bb488 100644
--- a/internal/server/base.go
+++ b/internal/server/base.go
@@ -19,7 +19,7 @@ type Base struct {
}
func (b *Base) InitializeEndpoints() error {
- ep, err := APIGetEndpoints(b.URL)
+ ep, err := APIGetEndpoints(b.URL, b.httpClient)
if err != nil {
return err
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 00324a2..c68916e 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -33,12 +33,14 @@ type EndpointList struct {
Token string `json:"token_endpoint"`
}
+type EndpointsVersions struct {
+ V2 EndpointList `json:"http://eduvpn.org/api#2"`
+ V3 EndpointList `json:"http://eduvpn.org/api#3"`
+}
+
// Endpoints defines the json format for /.well-known/vpn-user-portal".
type Endpoints struct {
- API struct {
- V2 EndpointList `json:"http://eduvpn.org/api#2"`
- V3 EndpointList `json:"http://eduvpn.org/api#3"`
- } `json:"api"`
+ API EndpointsVersions `json:"api"`
V string `json:"v"`
}
diff --git a/internal/test/handler.go b/internal/test/handler.go
new file mode 100644
index 0000000..5c02629
--- /dev/null
+++ b/internal/test/handler.go
@@ -0,0 +1,25 @@
+package test
+
+import (
+ "net/http"
+ "sync"
+)
+
+// HandlerSet is a struct with a mutex that allows us to swap handlers while a test server is running
+type HandlerSet struct {
+ mu sync.Mutex
+ handler http.Handler
+}
+
+func (hs *HandlerSet) SetHandler(handler http.Handler) {
+ hs.mu.Lock()
+ hs.handler = handler
+ hs.mu.Unlock()
+}
+
+func (hs *HandlerSet) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ hs.mu.Lock()
+ handler := hs.handler
+ hs.mu.Unlock()
+ handler.ServeHTTP(w, r)
+}