summaryrefslogtreecommitdiff
path: root/internal/server/api.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/api.go')
-rw-r--r--internal/server/api.go51
1 files changed, 42 insertions, 9 deletions
diff --git a/internal/server/api.go b/internal/server/api.go
index d65e923..bede643 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -12,25 +12,58 @@ import (
"github.com/go-errors/errors"
)
-func APIGetEndpoints(baseURL string) (*Endpoints, error) {
- u, err := url.Parse(baseURL)
+func validateEndpoints(endpoints Endpoints) error {
+ v3 := endpoints.API.V3
+ pAPI, err := url.Parse(v3.API)
if err != nil {
- return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ return errors.WrapPrefix(err, "failed to parse API endpoint", 0)
+ }
+ pAuth, err := url.Parse(v3.Authorization)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0)
+ }
+ pToken, err := url.Parse(v3.Token)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API token endpoint", 0)
+ }
+ if pAPI.Scheme != pAuth.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
+ }
+ if pAPI.Scheme != pToken.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
+ }
+ if pAPI.Host != pAuth.Host {
+ return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host)
}
+ if pAPI.Host != pToken.Host {
+ return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host)
+ }
+ return nil
+}
- u.Path = path.Join(u.Path, "/.well-known/vpn-user-portal")
- c := httpw.NewClient()
- _, body, err := c.Get(u.String())
+func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) {
+ uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return nil, err
+ }
+ if client == nil {
+ client = httpw.NewClient()
+ }
+ _, body, err := client.Get(uStr)
if err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- ep := &Endpoints{}
- if err = json.Unmarshal(body, ep); err != nil {
+ ep := Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
+ err = validateEndpoints(ep)
+ if err != nil {
+ return nil, err
+ }
- return ep, nil
+ return &ep, nil
}
func apiAuthorized(