summaryrefslogtreecommitdiff
path: root/internal/server/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/api')
-rw-r--r--internal/server/api/api.go217
-rw-r--r--internal/server/api/api_test.go150
2 files changed, 0 insertions, 367 deletions
diff --git a/internal/server/api/api.go b/internal/server/api/api.go
deleted file mode 100644
index 9ad6f2d..0000000
--- a/internal/server/api/api.go
+++ /dev/null
@@ -1,217 +0,0 @@
-package api
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "net/url"
- "path"
- "time"
-
- httpw "github.com/eduvpn/eduvpn-common/internal/http"
- "github.com/eduvpn/eduvpn-common/internal/log"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/internal/server/base"
- "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
- "github.com/eduvpn/eduvpn-common/internal/server/profile"
- "github.com/go-errors/errors"
-)
-
-func Endpoints(ctx context.Context, b *base.Base) error {
- uStr, err := httpw.JoinURLPath(b.URL, "/.well-known/vpn-user-portal")
- if err != nil {
- return err
- }
- if b.HTTPClient == nil {
- b.HTTPClient = httpw.NewClient()
- }
- _, body, err := b.HTTPClient.Get(ctx, uStr)
- if err != nil {
- return errors.WrapPrefix(err, "failed getting server endpoints", 0)
- }
-
- ep := endpoints.Endpoints{}
- if err = json.Unmarshal(body, &ep); err != nil {
- return errors.WrapPrefix(err, "failed getting server endpoints", 0)
- }
- err = ep.Validate()
- if err != nil {
- return err
- }
-
- b.Endpoints = ep
- return nil
-}
-
-func authorized(
- ctx context.Context,
- b *base.Base,
- oauth *oauth.OAuth,
- method string,
- endpoint string,
- opts *httpw.OptionalParams,
-) (http.Header, []byte, error) {
- // Ensure optional is not nil as we will fill it with headers
- if opts == nil {
- opts = &httpw.OptionalParams{}
- }
- errorMessage := "failed API authorized"
-
- // Join the paths
- u, err := url.Parse(b.Endpoints.API.V3.API)
- if err != nil {
- return nil, nil, errors.WrapPrefix(err, errorMessage, 0)
- }
- u.Path = path.Join(u.Path, endpoint)
-
- // Make sure the tokens are valid, this will return an error if re-login is needed
- t, err := oauth.AccessToken(ctx)
- if err != nil {
- return nil, nil, errors.WrapPrefix(err, errorMessage, 0)
- }
-
- key := "Authorization"
- val := fmt.Sprintf("Bearer %s", t)
- if opts.Headers != nil {
- opts.Headers.Add(key, val)
- } else {
- opts.Headers = http.Header{key: {val}}
- }
-
- // Create a client if it doesn't exist
- if b.HTTPClient == nil {
- b.HTTPClient = httpw.NewClient()
- }
- return b.HTTPClient.Do(ctx, method, u.String(), opts)
-}
-
-func authorizedRetry(
- ctx context.Context,
- b *base.Base,
- auth *oauth.OAuth,
- method string,
- endpoint string,
- opts *httpw.OptionalParams,
-) (http.Header, []byte, error) {
- h, body, err := authorized(ctx, b, auth, method, endpoint, opts)
- if err == nil {
- return h, body, nil
- }
-
- statErr := &httpw.StatusError{}
- // Only retry authorized if we get an HTTP 401
- if errors.As(err, &statErr) && statErr.Status == 401 {
- log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint)
- // Mark the token as expired and retry, so we trigger the refresh flow
- auth.SetTokenExpired()
- h, body, err = authorized(ctx, b, auth, method, endpoint, opts)
- }
- return h, body, err
-}
-
-func Info(ctx context.Context, b *base.Base, auth *oauth.OAuth) error {
- _, body, err := authorizedRetry(ctx, b, auth, http.MethodGet, "/info", nil)
- if err != nil {
- return err
- }
- profiles := profile.Info{}
- if err = json.Unmarshal(body, &profiles); err != nil {
- return errors.WrapPrefix(err, "failed API /info", 0)
- }
-
- // Store the profiles and make sure that the current profile is not overwritten
- prev := b.Profiles.Current
- b.Profiles = profiles
- b.Profiles.Current = prev
- return nil
-}
-
-// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
-func boolToYesNo(preferTCP bool) string {
- if preferTCP {
- return "yes"
- }
- return "no"
-}
-
-func ConnectWireguard(
- ctx context.Context,
- b *base.Base,
- auth *oauth.OAuth,
- profileID string,
- pubkey string,
- preferTCP bool,
- openVPNSupport bool,
-) (string, string, time.Time, error) {
- hdrs := http.Header{
- "content-type": {"application/x-www-form-urlencoded"},
- "accept": {"application/x-wireguard-profile"},
- }
-
- // This profile also supports OpenVPN
- // Indicate that we also accept OpenVPN profiles
- if openVPNSupport {
- hdrs.Add("accept", "application/x-openvpn-profile")
- }
-
- vals := url.Values{
- "profile_id": {profileID},
- "public_key": {pubkey},
- "prefer_tcp": {boolToYesNo(preferTCP)},
- }
- h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect",
- &httpw.OptionalParams{Headers: hdrs, Body: vals})
- if err != nil {
- return "", "", time.Time{}, err
- }
-
- exp := h.Get("expires")
-
- expTime, err := http.ParseTime(exp)
- if err != nil {
- return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0)
- }
-
- contentH := h.Get("content-type")
- content := "openvpn"
- if contentH == "application/x-wireguard-profile" {
- content = "wireguard"
- }
-
- return string(body), content, expTime, nil
-}
-
-func ConnectOpenVPN(ctx context.Context, b *base.Base, auth *oauth.OAuth, profileID string, preferTCP bool) (string, time.Time, error) {
- hdrs := http.Header{
- "content-type": {"application/x-www-form-urlencoded"},
- "accept": {"application/x-openvpn-profile"},
- }
-
- vals := url.Values{
- "profile_id": {profileID},
- "prefer_tcp": {boolToYesNo(preferTCP)},
- }
-
- h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect",
- &httpw.OptionalParams{Headers: hdrs, Body: vals})
- if err != nil {
- return "", time.Time{}, err
- }
-
- expH := h.Get("expires")
- expT, err := http.ParseTime(expH)
- if err != nil {
- return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0)
- }
-
- return string(body), expT, nil
-}
-
-// Disconnect disconnects the VPN using the API.
-func Disconnect(ctx context.Context, b *base.Base, auth *oauth.OAuth) error {
- // The timeout is a bit lower here such that this does not take a too long time for disconnecting
- // Clients may wish to retry this
- _, _, err := authorized(ctx, b, auth, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second})
- return err
-}
diff --git a/internal/server/api/api_test.go b/internal/server/api/api_test.go
deleted file mode 100644
index 2fea4c6..0000000
--- a/internal/server/api/api_test.go
+++ /dev/null
@@ -1,150 +0,0 @@
-package api
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "testing"
-
- "github.com/eduvpn/eduvpn-common/internal/server/base"
- "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
- "github.com/eduvpn/eduvpn-common/internal/test"
- "github.com/go-errors/errors"
-)
-
-func getErrorMsg(err error) string {
- if err == nil {
- return ""
- }
- return err.Error()
-}
-
-func compareEndpoints(ep1 endpoints.Endpoints, ep2 endpoints.Endpoints) bool {
- v3_1 := ep1.API.V3
- v3_2 := ep2.API.V3
- return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token
-}
-
-func Test_APIGetEndpoints(t *testing.T) {
- handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- fmt.Fprintln(w, "Hello!")
- })
- hs := &test.HandlerSet{}
- hs.SetHandler(handler)
- s := test.NewServer(hs)
- defer s.Close()
-
- c, err := s.Client()
- if err != nil {
- t.Fatalf("failed to get client for test server endpoints: %v", err)
- }
-
- testCases := []struct {
- epl endpoints.List
- err error
- }{
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "https://example.com/2",
- Token: "https://example.com/3",
- },
- err: nil,
- },
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "http://example.com/2",
- Token: "http://example.com/3",
- },
- err: errors.New("API scheme: 'https', is not equal to authorization scheme: 'http'"),
- },
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "https://example.com/2",
- Token: "ftp://example.com/3",
- },
- err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"),
- },
- {
- epl: endpoints.List{
- API: "https://malicious.com/1",
- Authorization: "https://example.com/2",
- Token: "https://example.com/3",
- },
- err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"),
- },
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "https://example.com/2",
- Token: "https://malicious.com/3",
- },
- err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"),
- },
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "https://malicious.com/2",
- Token: "https://example.com/3",
- },
- err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"),
- },
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "https://example.com/2",
- Token: "ftp://example.com/3",
- },
- err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"),
- },
- {
- epl: endpoints.List{
- API: "https://example.com/1",
- Authorization: "ftp://example.com/2",
- Token: "https://example.com/3",
- },
- err: errors.New("API scheme: 'https', is not equal to authorization scheme: 'ftp'"),
- },
- }
-
- for _, tc := range testCases {
- ep := &endpoints.Endpoints{
- API: endpoints.Versions{
- V3: tc.epl,
- },
- }
- // Update the handler
- hs.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
-
- jsonStr, err := json.Marshal(ep)
- if err != nil {
- t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err)
- }
-
- fmt.Fprintln(w, string(jsonStr))
- }))
- b := &base.Base{
- URL: s.URL,
- HTTPClient: c,
- }
- err = Endpoints(context.Background(), b)
- if getErrorMsg(err) != getErrorMsg(tc.err) {
- t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err)
- }
- // The error was not nil, continue because endpoints should not be compared
- if tc.err != nil {
- continue
- }
- if ep == nil {
- t.Fatalf("No test case endpoints")
- }
- // if no error then the endpoints should be equal
- if !compareEndpoints(*ep, b.Endpoints) {
- t.Fatalf("Endpoints are not equal, got: %v, want: %v", b.Endpoints, ep)
- }
- }
-}