diff options
Diffstat (limited to 'internal/server/api')
| -rw-r--r-- | internal/server/api/api.go | 217 | ||||
| -rw-r--r-- | internal/server/api/api_test.go | 134 |
2 files changed, 351 insertions, 0 deletions
diff --git a/internal/server/api/api.go b/internal/server/api/api.go new file mode 100644 index 0000000..9ad6f2d --- /dev/null +++ b/internal/server/api/api.go @@ -0,0 +1,217 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "path" + "time" + + httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/server/profile" + "github.com/go-errors/errors" +) + +func Endpoints(ctx context.Context, b *base.Base) error { + uStr, err := httpw.JoinURLPath(b.URL, "/.well-known/vpn-user-portal") + if err != nil { + return err + } + if b.HTTPClient == nil { + b.HTTPClient = httpw.NewClient() + } + _, body, err := b.HTTPClient.Get(ctx, uStr) + if err != nil { + return errors.WrapPrefix(err, "failed getting server endpoints", 0) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return errors.WrapPrefix(err, "failed getting server endpoints", 0) + } + err = ep.Validate() + if err != nil { + return err + } + + b.Endpoints = ep + return nil +} + +func authorized( + ctx context.Context, + b *base.Base, + oauth *oauth.OAuth, + method string, + endpoint string, + opts *httpw.OptionalParams, +) (http.Header, []byte, error) { + // Ensure optional is not nil as we will fill it with headers + if opts == nil { + opts = &httpw.OptionalParams{} + } + errorMessage := "failed API authorized" + + // Join the paths + u, err := url.Parse(b.Endpoints.API.V3.API) + if err != nil { + return nil, nil, errors.WrapPrefix(err, errorMessage, 0) + } + u.Path = path.Join(u.Path, endpoint) + + // Make sure the tokens are valid, this will return an error if re-login is needed + t, err := oauth.AccessToken(ctx) + if err != nil { + return nil, nil, errors.WrapPrefix(err, errorMessage, 0) + } + + key := "Authorization" + val := fmt.Sprintf("Bearer %s", t) + if opts.Headers != nil { + opts.Headers.Add(key, val) + } else { + opts.Headers = http.Header{key: {val}} + } + + // Create a client if it doesn't exist + if b.HTTPClient == nil { + b.HTTPClient = httpw.NewClient() + } + return b.HTTPClient.Do(ctx, method, u.String(), opts) +} + +func authorizedRetry( + ctx context.Context, + b *base.Base, + auth *oauth.OAuth, + method string, + endpoint string, + opts *httpw.OptionalParams, +) (http.Header, []byte, error) { + h, body, err := authorized(ctx, b, auth, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpw.StatusError{} + // Only retry authorized if we get an HTTP 401 + if errors.As(err, &statErr) && statErr.Status == 401 { + log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + auth.SetTokenExpired() + h, body, err = authorized(ctx, b, auth, method, endpoint, opts) + } + return h, body, err +} + +func Info(ctx context.Context, b *base.Base, auth *oauth.OAuth) error { + _, body, err := authorizedRetry(ctx, b, auth, http.MethodGet, "/info", nil) + if err != nil { + return err + } + profiles := profile.Info{} + if err = json.Unmarshal(body, &profiles); err != nil { + return errors.WrapPrefix(err, "failed API /info", 0) + } + + // Store the profiles and make sure that the current profile is not overwritten + prev := b.Profiles.Current + b.Profiles = profiles + b.Profiles.Current = prev + return nil +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func ConnectWireguard( + ctx context.Context, + b *base.Base, + auth *oauth.OAuth, + profileID string, + pubkey string, + preferTCP bool, + openVPNSupport bool, +) (string, string, time.Time, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + "accept": {"application/x-wireguard-profile"}, + } + + // This profile also supports OpenVPN + // Indicate that we also accept OpenVPN profiles + if openVPNSupport { + hdrs.Add("accept", "application/x-openvpn-profile") + } + + vals := url.Values{ + "profile_id": {profileID}, + "public_key": {pubkey}, + "prefer_tcp": {boolToYesNo(preferTCP)}, + } + h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", "", time.Time{}, err + } + + exp := h.Get("expires") + + expTime, err := http.ParseTime(exp) + if err != nil { + return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0) + } + + contentH := h.Get("content-type") + content := "openvpn" + if contentH == "application/x-wireguard-profile" { + content = "wireguard" + } + + return string(body), content, expTime, nil +} + +func ConnectOpenVPN(ctx context.Context, b *base.Base, auth *oauth.OAuth, profileID string, preferTCP bool) (string, time.Time, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + "accept": {"application/x-openvpn-profile"}, + } + + vals := url.Values{ + "profile_id": {profileID}, + "prefer_tcp": {boolToYesNo(preferTCP)}, + } + + h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", time.Time{}, err + } + + expH := h.Get("expires") + expT, err := http.ParseTime(expH) + if err != nil { + return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0) + } + + return string(body), expT, nil +} + +// Disconnect disconnects the VPN using the API. +func Disconnect(ctx context.Context, b *base.Base, auth *oauth.OAuth) error { + // The timeout is a bit lower here such that this does not take a too long time for disconnecting + // Clients may wish to retry this + _, _, err := authorized(ctx, b, auth, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) + return err +} diff --git a/internal/server/api/api_test.go b/internal/server/api/api_test.go new file mode 100644 index 0000000..7509a30 --- /dev/null +++ b/internal/server/api/api_test.go @@ -0,0 +1,134 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/go-errors/errors" +) + +func getErrorMsg(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func compareEndpoints(ep1 endpoints.Endpoints, ep2 endpoints.Endpoints) bool { + v3_1 := ep1.API.V3 + v3_2 := ep2.API.V3 + return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token +} + +func Test_APIGetEndpoints(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello!") + }) + hs := &test.HandlerSet{} + hs.SetHandler(handler) + s := test.NewServer(hs) + defer s.Close() + + c, err := s.Client() + if err != nil { + t.Fatalf("failed to get client for test server endpoints: %v", err) + } + + testCases := []struct { + epl endpoints.List + err error + }{ + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: nil, + }, + { + epl: endpoints.List{ + API: "http://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"), + }, + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "ftp://example.com/3", + }, + err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"), + }, + { + epl: endpoints.List{ + API: "https://malicious.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"), + }, + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://malicious.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"), + }, + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://malicious.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"), + }, + } + + for _, tc := range testCases { + ep := &endpoints.Endpoints{ + API: endpoints.Versions{ + V3: tc.epl, + }, + } + // Update the handler + hs.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + jsonStr, err := json.Marshal(ep) + if err != nil { + t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err) + } + + fmt.Fprintln(w, string(jsonStr)) + })) + b := &base.Base{ + URL: s.URL, + HTTPClient: c, + } + err = Endpoints(context.Background(), b) + if getErrorMsg(err) != getErrorMsg(tc.err) { + t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err) + } + // The error was not nil, continue because endpoints should not be compared + if tc.err != nil { + continue + } + if ep == nil { + t.Fatalf("No test case endpoints") + } + // if no error then the endpoints should be equal + if !compareEndpoints(*ep, b.Endpoints) { + t.Fatalf("Endpoints are not equal, got: %v, want: %v", b.Endpoints, ep) + } + } +} |
