summaryrefslogtreecommitdiff
path: root/internal/server/api
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/api')
-rw-r--r--internal/server/api/api.go217
-rw-r--r--internal/server/api/api_test.go134
2 files changed, 351 insertions, 0 deletions
diff --git a/internal/server/api/api.go b/internal/server/api/api.go
new file mode 100644
index 0000000..9ad6f2d
--- /dev/null
+++ b/internal/server/api/api.go
@@ -0,0 +1,217 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "path"
+ "time"
+
+ httpw "github.com/eduvpn/eduvpn-common/internal/http"
+ "github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
+ "github.com/eduvpn/eduvpn-common/internal/server/profile"
+ "github.com/go-errors/errors"
+)
+
+func Endpoints(ctx context.Context, b *base.Base) error {
+ uStr, err := httpw.JoinURLPath(b.URL, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return err
+ }
+ if b.HTTPClient == nil {
+ b.HTTPClient = httpw.NewClient()
+ }
+ _, body, err := b.HTTPClient.Get(ctx, uStr)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ }
+
+ ep := endpoints.Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
+ return errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ }
+ err = ep.Validate()
+ if err != nil {
+ return err
+ }
+
+ b.Endpoints = ep
+ return nil
+}
+
+func authorized(
+ ctx context.Context,
+ b *base.Base,
+ oauth *oauth.OAuth,
+ method string,
+ endpoint string,
+ opts *httpw.OptionalParams,
+) (http.Header, []byte, error) {
+ // Ensure optional is not nil as we will fill it with headers
+ if opts == nil {
+ opts = &httpw.OptionalParams{}
+ }
+ errorMessage := "failed API authorized"
+
+ // Join the paths
+ u, err := url.Parse(b.Endpoints.API.V3.API)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, errorMessage, 0)
+ }
+ u.Path = path.Join(u.Path, endpoint)
+
+ // Make sure the tokens are valid, this will return an error if re-login is needed
+ t, err := oauth.AccessToken(ctx)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, errorMessage, 0)
+ }
+
+ key := "Authorization"
+ val := fmt.Sprintf("Bearer %s", t)
+ if opts.Headers != nil {
+ opts.Headers.Add(key, val)
+ } else {
+ opts.Headers = http.Header{key: {val}}
+ }
+
+ // Create a client if it doesn't exist
+ if b.HTTPClient == nil {
+ b.HTTPClient = httpw.NewClient()
+ }
+ return b.HTTPClient.Do(ctx, method, u.String(), opts)
+}
+
+func authorizedRetry(
+ ctx context.Context,
+ b *base.Base,
+ auth *oauth.OAuth,
+ method string,
+ endpoint string,
+ opts *httpw.OptionalParams,
+) (http.Header, []byte, error) {
+ h, body, err := authorized(ctx, b, auth, method, endpoint, opts)
+ if err == nil {
+ return h, body, nil
+ }
+
+ statErr := &httpw.StatusError{}
+ // Only retry authorized if we get an HTTP 401
+ if errors.As(err, &statErr) && statErr.Status == 401 {
+ log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint)
+ // Mark the token as expired and retry, so we trigger the refresh flow
+ auth.SetTokenExpired()
+ h, body, err = authorized(ctx, b, auth, method, endpoint, opts)
+ }
+ return h, body, err
+}
+
+func Info(ctx context.Context, b *base.Base, auth *oauth.OAuth) error {
+ _, body, err := authorizedRetry(ctx, b, auth, http.MethodGet, "/info", nil)
+ if err != nil {
+ return err
+ }
+ profiles := profile.Info{}
+ if err = json.Unmarshal(body, &profiles); err != nil {
+ return errors.WrapPrefix(err, "failed API /info", 0)
+ }
+
+ // Store the profiles and make sure that the current profile is not overwritten
+ prev := b.Profiles.Current
+ b.Profiles = profiles
+ b.Profiles.Current = prev
+ return nil
+}
+
+// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
+func boolToYesNo(preferTCP bool) string {
+ if preferTCP {
+ return "yes"
+ }
+ return "no"
+}
+
+func ConnectWireguard(
+ ctx context.Context,
+ b *base.Base,
+ auth *oauth.OAuth,
+ profileID string,
+ pubkey string,
+ preferTCP bool,
+ openVPNSupport bool,
+) (string, string, time.Time, error) {
+ hdrs := http.Header{
+ "content-type": {"application/x-www-form-urlencoded"},
+ "accept": {"application/x-wireguard-profile"},
+ }
+
+ // This profile also supports OpenVPN
+ // Indicate that we also accept OpenVPN profiles
+ if openVPNSupport {
+ hdrs.Add("accept", "application/x-openvpn-profile")
+ }
+
+ vals := url.Values{
+ "profile_id": {profileID},
+ "public_key": {pubkey},
+ "prefer_tcp": {boolToYesNo(preferTCP)},
+ }
+ h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect",
+ &httpw.OptionalParams{Headers: hdrs, Body: vals})
+ if err != nil {
+ return "", "", time.Time{}, err
+ }
+
+ exp := h.Get("expires")
+
+ expTime, err := http.ParseTime(exp)
+ if err != nil {
+ return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0)
+ }
+
+ contentH := h.Get("content-type")
+ content := "openvpn"
+ if contentH == "application/x-wireguard-profile" {
+ content = "wireguard"
+ }
+
+ return string(body), content, expTime, nil
+}
+
+func ConnectOpenVPN(ctx context.Context, b *base.Base, auth *oauth.OAuth, profileID string, preferTCP bool) (string, time.Time, error) {
+ hdrs := http.Header{
+ "content-type": {"application/x-www-form-urlencoded"},
+ "accept": {"application/x-openvpn-profile"},
+ }
+
+ vals := url.Values{
+ "profile_id": {profileID},
+ "prefer_tcp": {boolToYesNo(preferTCP)},
+ }
+
+ h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect",
+ &httpw.OptionalParams{Headers: hdrs, Body: vals})
+ if err != nil {
+ return "", time.Time{}, err
+ }
+
+ expH := h.Get("expires")
+ expT, err := http.ParseTime(expH)
+ if err != nil {
+ return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0)
+ }
+
+ return string(body), expT, nil
+}
+
+// Disconnect disconnects the VPN using the API.
+func Disconnect(ctx context.Context, b *base.Base, auth *oauth.OAuth) error {
+ // The timeout is a bit lower here such that this does not take a too long time for disconnecting
+ // Clients may wish to retry this
+ _, _, err := authorized(ctx, b, auth, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second})
+ return err
+}
diff --git a/internal/server/api/api_test.go b/internal/server/api/api_test.go
new file mode 100644
index 0000000..7509a30
--- /dev/null
+++ b/internal/server/api/api_test.go
@@ -0,0 +1,134 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
+ "github.com/eduvpn/eduvpn-common/internal/test"
+ "github.com/go-errors/errors"
+)
+
+func getErrorMsg(err error) string {
+ if err == nil {
+ return ""
+ }
+ return err.Error()
+}
+
+func compareEndpoints(ep1 endpoints.Endpoints, ep2 endpoints.Endpoints) bool {
+ v3_1 := ep1.API.V3
+ v3_2 := ep2.API.V3
+ return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token
+}
+
+func Test_APIGetEndpoints(t *testing.T) {
+ handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello!")
+ })
+ hs := &test.HandlerSet{}
+ hs.SetHandler(handler)
+ s := test.NewServer(hs)
+ defer s.Close()
+
+ c, err := s.Client()
+ if err != nil {
+ t.Fatalf("failed to get client for test server endpoints: %v", err)
+ }
+
+ testCases := []struct {
+ epl endpoints.List
+ err error
+ }{
+ {
+ epl: endpoints.List{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: nil,
+ },
+ {
+ epl: endpoints.List{
+ API: "http://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"),
+ },
+ {
+ epl: endpoints.List{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "ftp://example.com/3",
+ },
+ err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"),
+ },
+ {
+ epl: endpoints.List{
+ API: "https://malicious.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"),
+ },
+ {
+ epl: endpoints.List{
+ API: "https://example.com/1",
+ Authorization: "https://example.com/2",
+ Token: "https://malicious.com/3",
+ },
+ err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"),
+ },
+ {
+ epl: endpoints.List{
+ API: "https://example.com/1",
+ Authorization: "https://malicious.com/2",
+ Token: "https://example.com/3",
+ },
+ err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"),
+ },
+ }
+
+ for _, tc := range testCases {
+ ep := &endpoints.Endpoints{
+ API: endpoints.Versions{
+ V3: tc.epl,
+ },
+ }
+ // Update the handler
+ hs.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+
+ jsonStr, err := json.Marshal(ep)
+ if err != nil {
+ t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err)
+ }
+
+ fmt.Fprintln(w, string(jsonStr))
+ }))
+ b := &base.Base{
+ URL: s.URL,
+ HTTPClient: c,
+ }
+ err = Endpoints(context.Background(), b)
+ if getErrorMsg(err) != getErrorMsg(tc.err) {
+ t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err)
+ }
+ // The error was not nil, continue because endpoints should not be compared
+ if tc.err != nil {
+ continue
+ }
+ if ep == nil {
+ t.Fatalf("No test case endpoints")
+ }
+ // if no error then the endpoints should be equal
+ if !compareEndpoints(*ep, b.Endpoints) {
+ t.Fatalf("Endpoints are not equal, got: %v, want: %v", b.Endpoints, ep)
+ }
+ }
+}