summaryrefslogtreecommitdiff
path: root/internal/eduvpnapi
diff options
context:
space:
mode:
Diffstat (limited to 'internal/eduvpnapi')
-rw-r--r--internal/eduvpnapi/cache.go67
-rw-r--r--internal/eduvpnapi/eduvpnapi.go395
-rw-r--r--internal/eduvpnapi/eduvpnapi_test.go513
-rw-r--r--internal/eduvpnapi/endpoints/endpoints.go62
-rw-r--r--internal/eduvpnapi/profiles/profiles.go119
-rw-r--r--internal/eduvpnapi/redirect.go28
6 files changed, 1184 insertions, 0 deletions
diff --git a/internal/eduvpnapi/cache.go b/internal/eduvpnapi/cache.go
new file mode 100644
index 0000000..d3e2b77
--- /dev/null
+++ b/internal/eduvpnapi/cache.go
@@ -0,0 +1,67 @@
+package eduvpnapi
+
+import (
+ "context"
+ "net/http"
+ "sync"
+ "time"
+
+ "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints"
+)
+
+// EndpointCache is a struct that caches well-known API endpoints
+type EndpointCache struct {
+ lastUpdate map[string]time.Time
+ lastEP map[string]*endpoints.Endpoints
+ mu sync.Mutex
+}
+
+// Get returns a cached or fresh endpoint cache copy
+func (ec *EndpointCache) Get(ctx context.Context, wk string, transport http.RoundTripper) (*endpoints.Endpoints, error) {
+ ec.mu.Lock()
+ defer ec.mu.Unlock()
+
+ // get the last update time
+ lu := time.Time{}
+ if v, ok := ec.lastUpdate[wk]; ok {
+ lu = v
+ }
+
+ // if not 10 minutes have passed, return cached copy
+ if !lu.IsZero() && !time.Now().After(lu.Add(10*time.Minute)) {
+ v, ok := ec.lastEP[wk]
+ if ok {
+ return v, nil
+ }
+ }
+
+ // get fresh API endpoints
+ ep, err := getEndpoints(ctx, wk, transport)
+ if err != nil {
+ return nil, err
+ }
+
+ // update endpoints
+ ec.lastUpdate[wk] = time.Now()
+ ec.lastEP[wk] = ep
+
+ return ep, nil
+}
+
+var (
+ epCache *EndpointCache
+ epCacheOnce sync.Once
+)
+
+// GetEndpointCache returns the global singleton endpoint cache
+// or creates one if it does not exist
+func GetEndpointCache() *EndpointCache {
+ epCacheOnce.Do(func() {
+ epCache = &EndpointCache{
+ lastUpdate: make(map[string]time.Time),
+ lastEP: make(map[string]*endpoints.Endpoints),
+ }
+ })
+
+ return epCache
+}
diff --git a/internal/eduvpnapi/eduvpnapi.go b/internal/eduvpnapi/eduvpnapi.go
new file mode 100644
index 0000000..62fe0d1
--- /dev/null
+++ b/internal/eduvpnapi/eduvpnapi.go
@@ -0,0 +1,395 @@
+// Package eduvpnapi implements version 3 of the eduVPN api: https://docs.eduvpn.org/server/v3/api.html
+package eduvpnapi
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "time"
+
+ "codeberg.org/jwijenbergh/eduoauth-go/v2"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints"
+ "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles"
+ "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap"
+ "codeberg.org/eduVPN/eduvpn-common/internal/wireguard"
+ "codeberg.org/eduVPN/eduvpn-common/types/protocol"
+ "codeberg.org/eduVPN/eduvpn-common/types/server"
+)
+
+// Callbacks is the API callback interface
+// It is used to trigger authorization and forward token updates
+type Callbacks interface {
+ // TriggerAuth is called when authorization should be triggered
+ TriggerAuth(context.Context, string, bool) (string, error)
+ // AuthDone is called when authorization has just completed
+ AuthDone(string, server.Type)
+ // TokensUpdates is called when tokens are updated
+ TokensUpdated(string, server.Type, eduoauth.Token)
+}
+
+// ServerData is the data for a server that is passed to the API struct
+type ServerData struct {
+ // ID is the identifier for the server
+ ID string
+ // Type is the type of server
+ Type server.Type
+ // BaseWK is the base well-known endpoint
+ BaseWK string
+ // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet
+ BaseAuthWK string
+ // ProcessAuth processes the OAuth authorization
+ ProcessAuth func(context.Context, string) (string, error)
+ // DisableAuthorize indicates whether or not new authorization requests should be disabled
+ DisableAuthorize bool
+ // transport is the HTTP transport, only used for testing currently
+ transport http.RoundTripper
+}
+
+// Transport returns the transport to be used for the server
+// By default it uses the transport from internal/httpwrap DefaultTransport
+func (s *ServerData) Transport() http.RoundTripper {
+ if s.transport == nil {
+ return httpwrap.DefaultTransport
+ }
+ return s.transport
+}
+
+// API is the top-level struct that each method is defined on
+type API struct {
+ cb Callbacks
+ // oauth is the oauth object
+ oauth *eduoauth.OAuth
+ // Data is the server data
+ Data ServerData
+}
+
+// NewAPI creates a new API object by creating an OAuth object
+func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) {
+ cr := customRedirect(clientID)
+ // Construct OAuth
+
+ transp := sd.Transport()
+ post := true
+ // we do not support non-loopback clients with response_mode form_post
+ if cr != "" {
+ post = false
+ }
+ o := eduoauth.OAuth{
+ ClientID: clientID,
+ EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) {
+ ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp)
+ if err != nil {
+ return nil, err
+ }
+ return &eduoauth.EndpointResponse{
+ AuthorizationURL: ep.API.V3.Authorization,
+ TokenURL: ep.API.V3.Token,
+ }, nil
+ },
+ CustomRedirect: cr,
+ FormPost: post,
+ RedirectPath: "/callback",
+ TokensUpdated: func(tok eduoauth.Token) {
+ cb.TokensUpdated(sd.ID, sd.Type, tok)
+ },
+ Transport: transp,
+ UserAgent: httpwrap.UserAgent,
+ }
+
+ if tokens != nil {
+ o.UpdateTokens(*tokens)
+ }
+
+ api := &API{
+ cb: cb,
+ oauth: &o,
+ Data: sd,
+ }
+ err := api.authorize(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return api, nil
+}
+
+// ErrAuthorizeDisabled is returned when authorization is disabled but is needed to complete
+var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled")
+
+func (a *API) authorize(ctx context.Context) (err error) {
+ _, err = a.oauth.AccessToken(ctx)
+ // already authorized
+ if err == nil {
+ return nil
+ }
+
+ // otherwise check if invalid tokens,
+ // if not then something else is wrong with the API
+ // return an error
+ tErr := &eduoauth.TokensInvalidError{}
+ if !errors.As(err, &tErr) {
+ return err
+ }
+
+ if a.Data.DisableAuthorize {
+ return ErrAuthorizeDisabled
+ }
+
+ defer func() {
+ if err == nil {
+ a.cb.AuthDone(a.Data.ID, a.Data.Type)
+ }
+ }()
+
+ scope := "config"
+ url, err := a.oauth.AuthURL(ctx, scope)
+ if err != nil {
+ return err
+ }
+ if a.Data.ProcessAuth != nil {
+ url, err = a.Data.ProcessAuth(ctx, url)
+ if err != nil {
+ return err
+ }
+ }
+ // We expect an uri if custom redirect is non empty
+ uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "")
+ if err != nil {
+ return err
+ }
+ // The uri is only given here if a custom redirect is done
+ err = a.oauth.Exchange(ctx, uri)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) {
+ ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport())
+ if err != nil {
+ return nil, nil, err
+ }
+ u := ep.API.V3.API + endpoint
+
+ // TODO: Cache HTTP client?
+ httpC := httpwrap.NewClient(a.oauth.NewHTTPClient())
+ return httpC.Do(ctx, method, u, opts)
+}
+
+func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) {
+ h, body, err := a.authorized(ctx, method, endpoint, opts)
+ if err == nil {
+ return h, body, nil
+ }
+
+ statErr := &httpwrap.StatusError{}
+ // Only retry authorized if we get an HTTP 401
+ // TODO: Can the OAuth client handle this instead?
+ if errors.As(err, &statErr) && statErr.Status == 401 {
+ slog.Debug("Got a HTTP 401. Marking tokens as expired...", "HTTP method", method, "endpoint", endpoint)
+ // Mark the token as expired and retry, so we trigger the refresh flow
+ a.oauth.SetTokenExpired()
+ h, body, err = a.authorized(ctx, method, endpoint, opts)
+ }
+ // Tokens is invalid we need to renew and authorize again
+ tErr := &eduoauth.TokensInvalidError{}
+ if err != nil && errors.As(err, &tErr) {
+ // Mark the token as invalid and retry, so we trigger the authorization flow
+ a.oauth.SetTokenRenew()
+ slog.Debug("The tokens were invalid, trying again...")
+ if autherr := a.authorize(ctx); autherr != nil {
+ return nil, nil, autherr
+ }
+ return a.authorized(ctx, method, endpoint, opts)
+ }
+ return h, body, err
+}
+
+// Disconnect disconnects a client from the server by sending a /disconnect API call
+// This cleans up resources such as WireGuard IP allocation
+func (a *API) Disconnect(ctx context.Context) error {
+ _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpwrap.OptionalParams{Timeout: 5 * time.Second})
+ return err
+}
+
+// Info does the /info API call
+func (a *API) Info(ctx context.Context) (*profiles.Info, error) {
+ _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed API /info: %w", err)
+ }
+ p := profiles.Info{}
+ if err = json.Unmarshal(body, &p); err != nil {
+ return nil, fmt.Errorf("failed API /info: %w", err)
+ }
+ return &p, nil
+}
+
+// ConnectData is the data that is returned when the /connect call completes without error
+type ConnectData struct {
+ // Configuration is the VPN configuration
+ Configuration string
+ // Protocol tells us what protocol it is, OpenVPN or WireGuard (proxied or not)
+ Protocol protocol.Protocol
+ // Expires tells us when this configuration expires
+ Expires time.Time
+}
+
+// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
+func boolToYesNo(preferTCP bool) string {
+ if preferTCP {
+ return "yes"
+ }
+ return "no"
+}
+
+func protocolFromCT(ct string) (protocol.Protocol, error) {
+ switch ct {
+ case "application/x-wireguard-profile":
+ return protocol.WireGuard, nil
+ case "application/x-wireguard+tcp-profile":
+ return protocol.WireGuardProxy, nil
+ case "application/x-openvpn-profile":
+ return protocol.OpenVPN, nil
+ }
+ return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct)
+}
+
+// ErrNoProtocols is returned when a connect call is given with an empty protocol slice
+var ErrNoProtocols = errors.New("no protocols supplied")
+
+// ErrUnknownProtocol is returned when the client in a connect gives an unknown protocol
+var ErrUnknownProtocol = errors.New("unknown protocol supplied")
+
+// Connect sends a /connect to an eduVPN server
+// `ctx` is the context used for cancellation
+// protos is the list of protocols supported and wanted by the client
+func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) {
+ hdrs := http.Header{
+ "content-type": {"application/x-www-form-urlencoded"},
+ }
+ uv := url.Values{
+ "profile_id": {prof.ID},
+ }
+
+ if len(protos) == 0 {
+ return nil, ErrNoProtocols
+ }
+
+ var wgKey *wgtypes.Key
+
+ // Loop over the protocols and set the correct headers and values
+ for _, p := range protos {
+ switch p {
+ case protocol.WireGuard:
+ gk, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ wgKey = &gk
+ // Set the public key
+ pubkey := wgKey.PublicKey()
+ uv.Set("public_key", pubkey.String())
+ hdrs.Add("accept", "application/x-wireguard-profile")
+ hdrs.Add("accept", "application/x-wireguard+tcp-profile")
+ case protocol.OpenVPN:
+ hdrs.Add("accept", "application/x-openvpn-profile")
+ default:
+ return nil, ErrUnknownProtocol
+ }
+ }
+ // set prefer TCP
+ uv.Set("prefer_tcp", boolToYesNo(pTCP))
+
+ // Construct the parameters
+ params := &httpwrap.OptionalParams{Headers: hdrs, Body: uv}
+ h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params)
+ if err != nil {
+ return nil, fmt.Errorf("failed API /connect call: %w", err)
+ }
+
+ // Parse expiry
+ expH := h.Get("expires")
+ if expH == "" {
+ return nil, errors.New("the server did not give an expires header")
+ }
+ expT, err := http.ParseTime(expH)
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing expiry time: %w", err)
+ }
+
+ vpnCfg := string(body)
+ // Parse content type
+ contentH := h.Get("content-type")
+ proto, err := protocolFromCT(contentH)
+ if err != nil {
+ return nil, err
+ }
+
+ if proto == protocol.OpenVPN {
+ // ensure scripts are not ran by default by append script-security 0 to the config
+ vpnCfg += "\nscript-security 0"
+ return &ConnectData{
+ Configuration: vpnCfg,
+ Protocol: proto,
+ Expires: expT,
+ }, nil
+ }
+
+ vpnCfg, err = wireguard.Config(vpnCfg, wgKey)
+ if err != nil {
+ return nil, err
+ }
+ return &ConnectData{
+ Configuration: vpnCfg,
+ Protocol: proto,
+ Expires: expT,
+ }, nil
+}
+
+func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) {
+ uStr, err := httpwrap.JoinURLPath(url, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return nil, err
+ }
+ httpC := httpwrap.NewClient(nil)
+ httpC.Client.Transport = tp
+ _, body, err := httpC.Get(ctx, uStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
+ }
+
+ ep := endpoints.Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
+ return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
+ }
+ err = ep.Validate()
+ if err != nil {
+ return nil, err
+ }
+ return &ep, nil
+}
+
+// OAuthLogger is defined here to update the internal logger
+// for the eduoauth library
+type OAuthLogger struct{}
+
+// Logf logs a message with parameters
+func (ol *OAuthLogger) Logf(msg string, params ...any) {
+ slog.Debug("OAuth log", "log", fmt.Sprintf(msg, params...))
+}
+
+// Log logs a message
+func (ol *OAuthLogger) Log(msg string) {
+ slog.Debug("OAuth log", "log", msg)
+}
+
+func init() {
+ eduoauth.UpdateLogger(&OAuthLogger{})
+}
diff --git a/internal/eduvpnapi/eduvpnapi_test.go b/internal/eduvpnapi/eduvpnapi_test.go
new file mode 100644
index 0000000..23f895b
--- /dev/null
+++ b/internal/eduvpnapi/eduvpnapi_test.go
@@ -0,0 +1,513 @@
+package eduvpnapi
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "net/url"
+ "reflect"
+ "regexp"
+ "slices"
+ "strings"
+ "testing"
+ "time"
+
+ "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles"
+ "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap"
+ "codeberg.org/eduVPN/eduvpn-common/internal/test"
+ "codeberg.org/eduVPN/eduvpn-common/types/protocol"
+ "codeberg.org/eduVPN/eduvpn-common/types/server"
+ "codeberg.org/jwijenbergh/eduoauth-go/v2"
+)
+
+func tokenHandler(t *testing.T, gt []string) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ t.Fatalf("invalid HTTP method for token handler: %v", r.Method)
+ }
+ b, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("failed reading token endpoint body: %v", err)
+ }
+ parsed, err := url.ParseQuery(string(b))
+ if err != nil {
+ t.Fatalf("failed parsing query body: %v", err)
+ }
+ grant := parsed.Get("grant_type")
+
+ if slices.Contains(gt, grant) {
+ _, err = w.Write([]byte(`
+{
+ "access_token": "validaccess",
+ "refresh_token": "validrefresh",
+ "expires_in": 3600
+}
+ `))
+ if err != nil {
+ t.Fatalf("failed writing in token handler: %v", err)
+ }
+ return
+ }
+ t.Fatalf("grant type: %v, not allowed", grant)
+ }
+}
+
+func checkAuthBearer(t *testing.T, r *http.Request) {
+ authh := r.Header.Get("Authorization")
+ if !strings.HasPrefix(authh, "Bearer ") {
+ t.Fatalf("API call is not given with an authorization Bearer header, got: %v", authh)
+ }
+}
+
+func connectHandler(t *testing.T, proto string, exp time.Time) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ t.Fatalf("invalid HTTP method for connect handler: %v", r.Method)
+ }
+ checkAuthBearer(t, r)
+ w.Header().Set("expires", exp.Format(http.TimeFormat))
+ w.Header().Set("content-type", fmt.Sprintf("application/x-%s-profile", proto))
+ b, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("failed reading token endpoint body: %v", err)
+ }
+ parsed, err := url.ParseQuery(string(b))
+ if err != nil {
+ t.Fatalf("failed parsing query body: %v", err)
+ }
+ // the wireguard config we parse
+ var cfg string
+ if proto == "openvpn" {
+ cfg = "openvpnconfig"
+ } else {
+ if parsed.Get("public_key") == "" {
+ t.Fatalf("no public_key given")
+ }
+ if proto == "wireguard+tcp" {
+ ptcp := parsed.Get("prefer_tcp")
+ if ptcp != "yes" {
+ t.Fatalf("prefer TCP is not yes: %s", ptcp)
+ }
+ cfg = `
+[Interface]
+[Peer]
+ProxyEndpoint = https://proxyendpoint
+`
+ } else {
+ cfg = "[Interface]"
+ }
+ }
+ _, err = w.Write([]byte(cfg))
+ if err != nil {
+ t.Fatalf("failed writing /connect response: %v", err)
+ }
+ }
+}
+
+func disconnectHandler(t *testing.T) func(http.ResponseWriter, *http.Request) {
+ return func(_ http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ t.Fatalf("invalid HTTP method for disconnect handler: %v", r.Method)
+ }
+ checkAuthBearer(t, r)
+ }
+}
+
+type TestCallback struct {
+ t *testing.T
+}
+
+func (tc *TestCallback) TriggerAuth(_ context.Context, str string, _ bool) (string, error) {
+ go func() {
+ u, err := url.Parse(str)
+ if err != nil {
+ panic(err)
+ }
+ ru, err := url.Parse(u.Query().Get("redirect_uri"))
+ if err != nil {
+ panic(err)
+ }
+ oq := u.Query()
+ q := ru.Query()
+ q.Set("state", oq.Get("state"))
+ q.Set("code", "fakeauthcode")
+ ru.RawQuery = q.Encode()
+
+ c := http.Client{}
+ req, err := http.NewRequest("GET", ru.String(), nil)
+ if err != nil {
+ panic(err)
+ }
+ _, err = c.Do(req)
+ if err != nil {
+ panic(err)
+ }
+ }()
+ return "", nil
+}
+func (tc *TestCallback) AuthDone(string, server.Type) {}
+func (tc *TestCallback) TokensUpdated(string, server.Type, eduoauth.Token) {}
+
+// create a API struct with allowed grant types
+func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.HandlerPath) (*API, *test.Server) {
+ // Create a simple API client and check if the fields are created correctly
+ listen, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("failed to setup listener for test server: %v", err)
+ }
+
+ hps = append(hps, []test.HandlerPath{
+ {
+ Method: http.MethodGet,
+ Path: "/.well-known/vpn-user-portal",
+ Response: fmt.Sprintf(`
+{
+ "api": {
+ "http://eduvpn.org/api#3": {
+ "api_endpoint": "https://%[1]s/test-api-endpoint",
+ "authorization_endpoint": "https://%[1]s/test-authorization-endpoint",
+ "token_endpoint": "https://%[1]s/test-token-endpoint"
+ }
+ },
+ "v": "0.0.0"
+}
+`, listen.Addr().String()),
+ },
+ {
+ Path: "/test-token-endpoint",
+ ResponseHandler: tokenHandler(t, gt),
+ },
+ }...)
+ // start server
+ serv := test.NewServerWithHandles(hps, listen)
+ servc, err := serv.Client()
+ if err != nil {
+ t.Fatalf("failed to setup HTTP test server client: %v", servc)
+ }
+
+ sd := ServerData{
+ ID: "randomidentifier",
+ Type: server.TypeCustom,
+ BaseWK: serv.URL,
+ BaseAuthWK: serv.URL,
+ ProcessAuth: func(_ context.Context, in string) (string, error) {
+ return in, nil
+ },
+ DisableAuthorize: false,
+ transport: servc.Client.Transport,
+ }
+
+ tc := &TestCallback{t: t}
+
+ a, err := NewAPI(context.Background(), "testclient", sd, tc, tok)
+ if err != nil {
+ t.Fatalf("failed creating API: %v", err)
+ }
+ return a, serv
+}
+
+func TestNewAPI(t *testing.T) {
+ gts := []string{"refresh_token"}
+ tok := &eduoauth.Token{
+ Access: "expiredaccess",
+ Refresh: "expiredrefresh",
+ // tokens are expired, let's try authorizing
+ ExpiredTimestamp: time.Now(),
+ }
+ a, srv := createTestAPI(t, tok, gts, nil)
+ srv.Close()
+
+ // now the tokens should be the new access tokens
+ if a.oauth.Token().Access != "validaccess" {
+ t.Fatalf("access token is not valid access")
+ }
+ if a.oauth.Token().Refresh != "validrefresh" {
+ t.Fatalf("refresh token is not valid refresh")
+ }
+
+ gts = []string{"authorization_code"}
+ tok = &eduoauth.Token{
+ Access: "expiredaccess",
+ Refresh: "",
+ ExpiredTimestamp: time.Now(),
+ }
+ a, srv = createTestAPI(t, tok, gts, nil)
+ srv.Close()
+
+ // now the tokens should be the new access tokens
+ if a.oauth.Token().Access != "validaccess" {
+ t.Fatalf("access token is not valid access")
+ }
+ if a.oauth.Token().Refresh != "validrefresh" {
+ t.Fatalf("refresh token is not valid refresh")
+ }
+}
+
+func TestAPIInfo(t *testing.T) {
+ // auth should not be triggered
+ var gts []string
+ tok := &eduoauth.Token{
+ Access: "validaccess",
+ Refresh: "validrefresh",
+ ExpiredTimestamp: time.Now().Add(1 * time.Hour),
+ }
+ statErr := &httpwrap.StatusError{}
+ cases := []struct {
+ hp test.HandlerPath
+ info *profiles.Info
+ err any
+ }{
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodGet,
+ Path: "/test-api-endpoint/info",
+ Response: `
+{
+ "info": {
+ "profile_list": [
+ {
+ "default_gateway": false,
+ "display_name": "test profile 1",
+ "profile_id": "test1",
+ "profile_priority": 3,
+ "vpn_proto_list": [
+ "openvpn",
+ "wireguard"
+ ]
+ }
+ ]
+ }
+}
+`,
+ },
+ info: &profiles.Info{
+ Info: profiles.ListInfo{
+ ProfileList: []profiles.Profile{
+ {
+ ID: "test1",
+ DisplayName: "test profile 1",
+ VPNProtoList: []string{"openvpn", "wireguard"},
+ Priority: 3,
+ DefaultGateway: false,
+ },
+ },
+ },
+ },
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodGet,
+ Path: "/test-api-endpoint/info",
+ Response: `
+{
+ "info": {
+ "profile_list": [
+ {
+ "display_name": "test profile 2",
+ "profile_id": "test2",
+ "vpn_proto_list": [
+ "wireguard"
+ ]
+ }
+ ]
+ }
+}
+`,
+ },
+ info: &profiles.Info{
+ Info: profiles.ListInfo{
+ ProfileList: []profiles.Profile{
+ {
+ ID: "test2",
+ DisplayName: "test profile 2",
+ VPNProtoList: []string{"wireguard"},
+ DefaultGateway: false,
+ },
+ },
+ },
+ },
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodGet,
+ Path: "/test-api-endpoint/info",
+ Response: "",
+ ResponseCode: 404,
+ },
+ info: nil,
+ err: &statErr,
+ },
+ }
+
+ for _, c := range cases {
+ a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp})
+ defer srv.Close()
+ gprfs, err := a.Info(context.Background())
+ // got error but the want error is nil
+ if err != nil {
+ if c.err == nil {
+ t.Fatalf("failed profiles info: %v but want no error", err)
+ }
+
+ if !errors.As(err, c.err) {
+ t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err.Error())
+ }
+ } else if c.err != nil {
+ t.Fatalf("got no error but want error: %T", c.err)
+ }
+
+ if !reflect.DeepEqual(gprfs, c.info) {
+ t.Fatalf("got info: %v, not equal to want: %v", gprfs, c.info)
+ }
+ }
+}
+
+func TestAPIConnect(t *testing.T) {
+ // auth should not be triggered
+ var gts []string
+ tok := &eduoauth.Token{
+ Access: "validaccess",
+ Refresh: "validrefresh",
+ ExpiredTimestamp: time.Now().Add(1 * time.Hour),
+ }
+ cases := []struct {
+ hp test.HandlerPath
+ cd *ConnectData
+ prof profiles.Profile
+ protos []protocol.Protocol
+ ptcp bool
+ err error
+ }{
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodPost,
+ Path: "/test-api-endpoint/connect",
+ Response: ``,
+ },
+ cd: nil,
+ err: ErrNoProtocols,
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodPost,
+ Path: "/test-api-endpoint/connect",
+ Response: ``,
+ },
+ cd: nil,
+ protos: []protocol.Protocol{protocol.Unknown},
+ err: ErrUnknownProtocol,
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodPost,
+ Path: "/test-api-endpoint/connect",
+ Response: ``,
+ },
+ cd: nil,
+ protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown},
+ err: ErrUnknownProtocol,
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodPost,
+ Path: "/test-api-endpoint/connect",
+ ResponseHandler: connectHandler(t, "openvpn", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)),
+ },
+ cd: &ConnectData{
+ Configuration: "openvpnconfig\nscript-security 0",
+ Protocol: protocol.OpenVPN,
+ Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC),
+ },
+ protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard},
+ err: nil,
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodPost,
+ Path: "/test-api-endpoint/connect",
+ ResponseHandler: connectHandler(t, "wireguard", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)),
+ },
+ cd: &ConnectData{
+ Configuration: `\[Interface\]
+PrivateKey = .*`,
+ Protocol: protocol.WireGuard,
+ Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC),
+ },
+ protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard},
+ err: nil,
+ },
+ {
+ hp: test.HandlerPath{
+ Method: http.MethodPost,
+ Path: "/test-api-endpoint/connect",
+ ResponseHandler: connectHandler(t, "wireguard+tcp", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)),
+ },
+ cd: &ConnectData{
+ Configuration: `\[Interface\]
+PrivateKey = .*`,
+ Protocol: protocol.WireGuardProxy,
+ Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC),
+ },
+ ptcp: true,
+ protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard},
+ err: nil,
+ },
+ }
+
+ for _, c := range cases {
+ a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp})
+ defer srv.Close()
+ gcd, err := a.Connect(context.Background(), c.prof, c.protos, c.ptcp)
+ // got error but the want error is nil
+ if err != nil {
+ if c.err == nil {
+ t.Fatalf("failed connect: %v but want no error", err)
+ }
+
+ if !errors.Is(err, c.err) {
+ t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err)
+ }
+ } else if c.err != nil {
+ t.Fatalf("got no error but want error: %T", c.err)
+ }
+
+ if gcd != nil && c.cd != nil {
+ m, err := regexp.MatchString(c.cd.Configuration, gcd.Configuration)
+ if err != nil {
+ t.Fatalf("failed matching regexp: %v", err)
+ }
+ if !m {
+ t.Fatalf("regex:\n%s\ndoes not match config:\n%s", c.cd.Configuration, gcd.Configuration)
+ }
+ // we have already checked the config using a regex
+ c.cd.Configuration = gcd.Configuration
+
+ }
+ if !reflect.DeepEqual(gcd, c.cd) {
+ t.Fatalf("got connect data: %v, not equal to want: %v", gcd, c.cd)
+ }
+ }
+}
+
+func TestDisconnect(t *testing.T) {
+ var gts []string
+ tok := &eduoauth.Token{
+ Access: "validaccess",
+ Refresh: "validrefresh",
+ ExpiredTimestamp: time.Now().Add(1 * time.Hour),
+ }
+ a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{
+ {
+ Path: "/test-api-endpoint/disconnect",
+ ResponseHandler: disconnectHandler(t),
+ },
+ })
+ defer srv.Close()
+ err := a.Disconnect(context.Background())
+ if err != nil {
+ t.Fatalf("failed /disconnect: %v", err)
+ }
+}
diff --git a/internal/eduvpnapi/endpoints/endpoints.go b/internal/eduvpnapi/endpoints/endpoints.go
new file mode 100644
index 0000000..c98d2c7
--- /dev/null
+++ b/internal/eduvpnapi/endpoints/endpoints.go
@@ -0,0 +1,62 @@
+// Package endpoints defines a wrapper around the various
+// endpoints returned by an eduVPN server in well-known
+package endpoints
+
+import (
+ "fmt"
+ "net/url"
+)
+
+// List is the list of endpoints as returned by the eduVPN server
+type List struct {
+ // API is the API endpoint which we use for calls such as /info, /connect, ...
+ API string `json:"api_endpoint"`
+ // Authorization is the authorization endpoint for OAuth
+ Authorization string `json:"authorization_endpoint"`
+ // Token is the token endpoint for OAuth
+ Token string `json:"token_endpoint"`
+}
+
+// Versions is the endpoints separated by API version
+type Versions struct {
+ // V2 is the legacy V2 API, this is not used
+ V2 List `json:"http://eduvpn.org/api#2"`
+ // V3 is the newest API, which we use
+ V3 List `json:"http://eduvpn.org/api#3"`
+}
+
+// Endpoints defines the json format for /.well-known/vpn-user-portal".
+type Endpoints struct {
+ // API defines the API endpoints, split by version
+ API Versions `json:"api"`
+ // V is the version string for the server
+ V string `json:"v"`
+}
+
+// Validate validates the endpoints by parsing them and checking the scheme is HTTP
+// An error is returned if they are not valid
+func (e Endpoints) Validate() error {
+ v3 := e.API.V3
+ pAPI, err := url.Parse(v3.API)
+ if err != nil {
+ return fmt.Errorf("failed to parse API endpoint: %w", err)
+ }
+ pAuth, err := url.Parse(v3.Authorization)
+ if err != nil {
+ return fmt.Errorf("failed to parse API authorization endpoint: %w", err)
+ }
+ pToken, err := url.Parse(v3.Token)
+ if err != nil {
+ return fmt.Errorf("failed to parse API token endpoint: %w", err)
+ }
+ if pAPI.Scheme != "https" {
+ return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme)
+ }
+ if pAPI.Scheme != pAuth.Scheme {
+ return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
+ }
+ if pAPI.Scheme != pToken.Scheme {
+ return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
+ }
+ return nil
+}
diff --git a/internal/eduvpnapi/profiles/profiles.go b/internal/eduvpnapi/profiles/profiles.go
new file mode 100644
index 0000000..77109f1
--- /dev/null
+++ b/internal/eduvpnapi/profiles/profiles.go
@@ -0,0 +1,119 @@
+// Package profiles defines a wrapper around the various profiles
+// returned by the /info endpoint
+package profiles
+
+import (
+ "codeberg.org/eduVPN/eduvpn-common/types/protocol"
+ "codeberg.org/eduVPN/eduvpn-common/types/server"
+)
+
+// Profile is the information for a profile
+type Profile struct {
+ // ID is the identifier of the profile
+ // Used to select a profile
+ ID string `json:"profile_id"`
+ // DisplayName defines the UI friendly name for the profile
+ DisplayName string `json:"display_name"`
+ // VPNProtoList defines the list of VPN protocols
+ // E.g. wireguard, openvpn
+ VPNProtoList []string `json:"vpn_proto_list"`
+ // VPNProtoTransportList defines the list of VPN protocols including their transport values
+ // E.g. wireguard+udp, openvpn+tcp
+ VPNProtoTransportList []string `json:"vpn_proto_transport_list"`
+ // DefaultGateway specifies whether or not this profile is a default gateway profile
+ DefaultGateway bool `json:"default_gateway"`
+ // DNSSearchDomains specifies the list of dns search domains
+ // This is provided for a Linux client issue
+ // See: https://github.com/eduvpn/python-eduvpn-client/issues/550
+ DNSSearchDomains []string `json:"dns_search_domain_list"`
+ // Priority is the priority of the profile for sorting in the UI
+ // the higher the priority, the higher it should be in the list
+ Priority int `json:"profile_priority"`
+}
+
+// ListInfo is the struct that has the profile list
+type ListInfo struct {
+ ProfileList []Profile `json:"profile_list"`
+}
+
+// Info is the top-level struct for the info endpoint
+type Info struct {
+ Info ListInfo `json:"info"`
+}
+
+// Len returns the length of the profile list
+func (i Info) Len() int {
+ return len(i.Info.ProfileList)
+}
+
+// Get returns a profile with id `id`, it returns nil if it is not found
+func (i Info) Get(id string) *Profile {
+ for _, p := range i.Info.ProfileList {
+ if p.ID == id {
+ return &p
+ }
+ }
+ return nil
+}
+
+// MustIndex gets a profile by index
+// This index must be in the bounds
+func (i Info) MustIndex(n int) Profile {
+ return i.Info.ProfileList[n]
+}
+
+func hasProtocol(protos []string, proto protocol.Protocol) bool {
+ for _, p := range protos {
+ if protocol.New(p) == proto {
+ return true
+ }
+ }
+ return false
+}
+
+// ShouldFailover returns whether or not this VPN profile should start a failover procedure
+// This is true when the profile supports a TCP connection
+// If we cannot determine whether it supports a TCP connection
+// (because the server doesn't provide the VPN transport list function yet),
+// we will just check if it supports OpenVPN
+func (p *Profile) ShouldFailover() bool {
+ // old servers don't support it, only failover in case OpenVPN is supported
+ if len(p.VPNProtoTransportList) == 0 {
+ // this checks VPNProtoList
+ return p.HasOpenVPN()
+ }
+ for _, c := range p.VPNProtoTransportList {
+ if c == "wireguard+tcp" {
+ return true
+ }
+ if c == "openvpn+tcp" {
+ return true
+ }
+ }
+ return false
+}
+
+// HasOpenVPN returns whether or not the profile has OpenVPN support
+func (p *Profile) HasOpenVPN() bool {
+ return hasProtocol(p.VPNProtoList, protocol.OpenVPN)
+}
+
+// HasWireGuard returns whether or not the profile has WireGuard support
+func (p *Profile) HasWireGuard() bool {
+ return hasProtocol(p.VPNProtoList, protocol.WireGuard)
+}
+
+// Public gets the server list as a structure that we return to clients
+func (i Info) Public() server.Profiles {
+ m := make(map[string]server.Profile)
+ for _, p := range i.Info.ProfileList {
+ m[p.ID] = server.Profile{
+ DisplayName: map[string]string{
+ "en": p.DisplayName,
+ },
+ DefaultGateway: p.DefaultGateway,
+ Priority: p.Priority,
+ }
+ }
+ return server.Profiles{Map: m}
+}
diff --git a/internal/eduvpnapi/redirect.go b/internal/eduvpnapi/redirect.go
new file mode 100644
index 0000000..7af31fb
--- /dev/null
+++ b/internal/eduvpnapi/redirect.go
@@ -0,0 +1,28 @@
+package eduvpnapi
+
+// customRedirects supplies redirect values that should be handled by the app itself
+// here we hardcode the redirect values that we should use in the OAuth requests
+// these values were taken from https://codeberg.org/eduVPN/vpn-user-portal/src/branch/v3/src/OAuth/VpnClientDb.php
+var customRedirects = map[string]string{
+ "org.letsconnect-vpn.app.macos": "org.letsconnect-vpn.app.macos:/api/callback",
+ "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback",
+ "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app.android:/api/callback",
+ "org.eduvpn.app.macos": "org.eduvpn.app.macos:/api/callback",
+ "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback",
+ "org.eduvpn.app.android": "org.eduvpn.app.android:/api/callback",
+ "org.govvpn.app.macos": "org.govvpn.app.macos:/api/callback",
+ "org.govvpn.app.ios": "org.govvpn.app.ios:/api/callback",
+ "org.govvpn.app.android": "org.govvpn.app.android:/api/callback",
+}
+
+// customRedirect returns the custom redirect string for the clientID `cid`
+// Empty string if none is defined or one is defined but is empty.
+// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects
+// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests
+func customRedirect(cid string) string {
+ v, ok := customRedirects[cid]
+ if !ok {
+ return ""
+ }
+ return v
+}