From a30ef6b27e578a4cf0a674b24f5b52b4c1516c63 Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Thu, 12 Feb 2026 12:34:08 +0100 Subject: All: Rename packages that sound useless or clash with std --- internal/eduvpnapi/cache.go | 67 ++++ internal/eduvpnapi/eduvpnapi.go | 395 +++++++++++++++++++++++ internal/eduvpnapi/eduvpnapi_test.go | 513 ++++++++++++++++++++++++++++++ internal/eduvpnapi/endpoints/endpoints.go | 62 ++++ internal/eduvpnapi/profiles/profiles.go | 119 +++++++ internal/eduvpnapi/redirect.go | 28 ++ 6 files changed, 1184 insertions(+) create mode 100644 internal/eduvpnapi/cache.go create mode 100644 internal/eduvpnapi/eduvpnapi.go create mode 100644 internal/eduvpnapi/eduvpnapi_test.go create mode 100644 internal/eduvpnapi/endpoints/endpoints.go create mode 100644 internal/eduvpnapi/profiles/profiles.go create mode 100644 internal/eduvpnapi/redirect.go (limited to 'internal/eduvpnapi') diff --git a/internal/eduvpnapi/cache.go b/internal/eduvpnapi/cache.go new file mode 100644 index 0000000..d3e2b77 --- /dev/null +++ b/internal/eduvpnapi/cache.go @@ -0,0 +1,67 @@ +package eduvpnapi + +import ( + "context" + "net/http" + "sync" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints" +) + +// EndpointCache is a struct that caches well-known API endpoints +type EndpointCache struct { + lastUpdate map[string]time.Time + lastEP map[string]*endpoints.Endpoints + mu sync.Mutex +} + +// Get returns a cached or fresh endpoint cache copy +func (ec *EndpointCache) Get(ctx context.Context, wk string, transport http.RoundTripper) (*endpoints.Endpoints, error) { + ec.mu.Lock() + defer ec.mu.Unlock() + + // get the last update time + lu := time.Time{} + if v, ok := ec.lastUpdate[wk]; ok { + lu = v + } + + // if not 10 minutes have passed, return cached copy + if !lu.IsZero() && !time.Now().After(lu.Add(10*time.Minute)) { + v, ok := ec.lastEP[wk] + if ok { + return v, nil + } + } + + // get fresh API endpoints + ep, err := getEndpoints(ctx, wk, transport) + if err != nil { + return nil, err + } + + // update endpoints + ec.lastUpdate[wk] = time.Now() + ec.lastEP[wk] = ep + + return ep, nil +} + +var ( + epCache *EndpointCache + epCacheOnce sync.Once +) + +// GetEndpointCache returns the global singleton endpoint cache +// or creates one if it does not exist +func GetEndpointCache() *EndpointCache { + epCacheOnce.Do(func() { + epCache = &EndpointCache{ + lastUpdate: make(map[string]time.Time), + lastEP: make(map[string]*endpoints.Endpoints), + } + }) + + return epCache +} diff --git a/internal/eduvpnapi/eduvpnapi.go b/internal/eduvpnapi/eduvpnapi.go new file mode 100644 index 0000000..62fe0d1 --- /dev/null +++ b/internal/eduvpnapi/eduvpnapi.go @@ -0,0 +1,395 @@ +// Package eduvpnapi implements version 3 of the eduVPN api: https://docs.eduvpn.org/server/v3/api.html +package eduvpnapi + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "time" + + "codeberg.org/jwijenbergh/eduoauth-go/v2" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/wireguard" + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" +) + +// Callbacks is the API callback interface +// It is used to trigger authorization and forward token updates +type Callbacks interface { + // TriggerAuth is called when authorization should be triggered + TriggerAuth(context.Context, string, bool) (string, error) + // AuthDone is called when authorization has just completed + AuthDone(string, server.Type) + // TokensUpdates is called when tokens are updated + TokensUpdated(string, server.Type, eduoauth.Token) +} + +// ServerData is the data for a server that is passed to the API struct +type ServerData struct { + // ID is the identifier for the server + ID string + // Type is the type of server + Type server.Type + // BaseWK is the base well-known endpoint + BaseWK string + // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet + BaseAuthWK string + // ProcessAuth processes the OAuth authorization + ProcessAuth func(context.Context, string) (string, error) + // DisableAuthorize indicates whether or not new authorization requests should be disabled + DisableAuthorize bool + // transport is the HTTP transport, only used for testing currently + transport http.RoundTripper +} + +// Transport returns the transport to be used for the server +// By default it uses the transport from internal/httpwrap DefaultTransport +func (s *ServerData) Transport() http.RoundTripper { + if s.transport == nil { + return httpwrap.DefaultTransport + } + return s.transport +} + +// API is the top-level struct that each method is defined on +type API struct { + cb Callbacks + // oauth is the oauth object + oauth *eduoauth.OAuth + // Data is the server data + Data ServerData +} + +// NewAPI creates a new API object by creating an OAuth object +func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) { + cr := customRedirect(clientID) + // Construct OAuth + + transp := sd.Transport() + post := true + // we do not support non-loopback clients with response_mode form_post + if cr != "" { + post = false + } + o := eduoauth.OAuth{ + ClientID: clientID, + EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) + if err != nil { + return nil, err + } + return &eduoauth.EndpointResponse{ + AuthorizationURL: ep.API.V3.Authorization, + TokenURL: ep.API.V3.Token, + }, nil + }, + CustomRedirect: cr, + FormPost: post, + RedirectPath: "/callback", + TokensUpdated: func(tok eduoauth.Token) { + cb.TokensUpdated(sd.ID, sd.Type, tok) + }, + Transport: transp, + UserAgent: httpwrap.UserAgent, + } + + if tokens != nil { + o.UpdateTokens(*tokens) + } + + api := &API{ + cb: cb, + oauth: &o, + Data: sd, + } + err := api.authorize(ctx) + if err != nil { + return nil, err + } + return api, nil +} + +// ErrAuthorizeDisabled is returned when authorization is disabled but is needed to complete +var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled") + +func (a *API) authorize(ctx context.Context) (err error) { + _, err = a.oauth.AccessToken(ctx) + // already authorized + if err == nil { + return nil + } + + // otherwise check if invalid tokens, + // if not then something else is wrong with the API + // return an error + tErr := &eduoauth.TokensInvalidError{} + if !errors.As(err, &tErr) { + return err + } + + if a.Data.DisableAuthorize { + return ErrAuthorizeDisabled + } + + defer func() { + if err == nil { + a.cb.AuthDone(a.Data.ID, a.Data.Type) + } + }() + + scope := "config" + url, err := a.oauth.AuthURL(ctx, scope) + if err != nil { + return err + } + if a.Data.ProcessAuth != nil { + url, err = a.Data.ProcessAuth(ctx, url) + if err != nil { + return err + } + } + // We expect an uri if custom redirect is non empty + uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") + if err != nil { + return err + } + // The uri is only given here if a custom redirect is done + err = a.oauth.Exchange(ctx, uri) + if err != nil { + return err + } + return nil +} + +func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) { + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) + if err != nil { + return nil, nil, err + } + u := ep.API.V3.API + endpoint + + // TODO: Cache HTTP client? + httpC := httpwrap.NewClient(a.oauth.NewHTTPClient()) + return httpC.Do(ctx, method, u, opts) +} + +func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) { + h, body, err := a.authorized(ctx, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpwrap.StatusError{} + // Only retry authorized if we get an HTTP 401 + // TODO: Can the OAuth client handle this instead? + if errors.As(err, &statErr) && statErr.Status == 401 { + slog.Debug("Got a HTTP 401. Marking tokens as expired...", "HTTP method", method, "endpoint", endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + a.oauth.SetTokenExpired() + h, body, err = a.authorized(ctx, method, endpoint, opts) + } + // Tokens is invalid we need to renew and authorize again + tErr := &eduoauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + // Mark the token as invalid and retry, so we trigger the authorization flow + a.oauth.SetTokenRenew() + slog.Debug("The tokens were invalid, trying again...") + if autherr := a.authorize(ctx); autherr != nil { + return nil, nil, autherr + } + return a.authorized(ctx, method, endpoint, opts) + } + return h, body, err +} + +// Disconnect disconnects a client from the server by sending a /disconnect API call +// This cleans up resources such as WireGuard IP allocation +func (a *API) Disconnect(ctx context.Context) error { + _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpwrap.OptionalParams{Timeout: 5 * time.Second}) + return err +} + +// Info does the /info API call +func (a *API) Info(ctx context.Context) (*profiles.Info, error) { + _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) + if err != nil { + return nil, fmt.Errorf("failed API /info: %w", err) + } + p := profiles.Info{} + if err = json.Unmarshal(body, &p); err != nil { + return nil, fmt.Errorf("failed API /info: %w", err) + } + return &p, nil +} + +// ConnectData is the data that is returned when the /connect call completes without error +type ConnectData struct { + // Configuration is the VPN configuration + Configuration string + // Protocol tells us what protocol it is, OpenVPN or WireGuard (proxied or not) + Protocol protocol.Protocol + // Expires tells us when this configuration expires + Expires time.Time +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func protocolFromCT(ct string) (protocol.Protocol, error) { + switch ct { + case "application/x-wireguard-profile": + return protocol.WireGuard, nil + case "application/x-wireguard+tcp-profile": + return protocol.WireGuardProxy, nil + case "application/x-openvpn-profile": + return protocol.OpenVPN, nil + } + return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct) +} + +// ErrNoProtocols is returned when a connect call is given with an empty protocol slice +var ErrNoProtocols = errors.New("no protocols supplied") + +// ErrUnknownProtocol is returned when the client in a connect gives an unknown protocol +var ErrUnknownProtocol = errors.New("unknown protocol supplied") + +// Connect sends a /connect to an eduVPN server +// `ctx` is the context used for cancellation +// protos is the list of protocols supported and wanted by the client +func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + uv := url.Values{ + "profile_id": {prof.ID}, + } + + if len(protos) == 0 { + return nil, ErrNoProtocols + } + + var wgKey *wgtypes.Key + + // Loop over the protocols and set the correct headers and values + for _, p := range protos { + switch p { + case protocol.WireGuard: + gk, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + wgKey = &gk + // Set the public key + pubkey := wgKey.PublicKey() + uv.Set("public_key", pubkey.String()) + hdrs.Add("accept", "application/x-wireguard-profile") + hdrs.Add("accept", "application/x-wireguard+tcp-profile") + case protocol.OpenVPN: + hdrs.Add("accept", "application/x-openvpn-profile") + default: + return nil, ErrUnknownProtocol + } + } + // set prefer TCP + uv.Set("prefer_tcp", boolToYesNo(pTCP)) + + // Construct the parameters + params := &httpwrap.OptionalParams{Headers: hdrs, Body: uv} + h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) + if err != nil { + return nil, fmt.Errorf("failed API /connect call: %w", err) + } + + // Parse expiry + expH := h.Get("expires") + if expH == "" { + return nil, errors.New("the server did not give an expires header") + } + expT, err := http.ParseTime(expH) + if err != nil { + return nil, fmt.Errorf("failed parsing expiry time: %w", err) + } + + vpnCfg := string(body) + // Parse content type + contentH := h.Get("content-type") + proto, err := protocolFromCT(contentH) + if err != nil { + return nil, err + } + + if proto == protocol.OpenVPN { + // ensure scripts are not ran by default by append script-security 0 to the config + vpnCfg += "\nscript-security 0" + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil + } + + vpnCfg, err = wireguard.Config(vpnCfg, wgKey) + if err != nil { + return nil, err + } + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil +} + +func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) { + uStr, err := httpwrap.JoinURLPath(url, "/.well-known/vpn-user-portal") + if err != nil { + return nil, err + } + httpC := httpwrap.NewClient(nil) + httpC.Client.Transport = tp + _, body, err := httpC.Get(ctx, uStr) + if err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + err = ep.Validate() + if err != nil { + return nil, err + } + return &ep, nil +} + +// OAuthLogger is defined here to update the internal logger +// for the eduoauth library +type OAuthLogger struct{} + +// Logf logs a message with parameters +func (ol *OAuthLogger) Logf(msg string, params ...any) { + slog.Debug("OAuth log", "log", fmt.Sprintf(msg, params...)) +} + +// Log logs a message +func (ol *OAuthLogger) Log(msg string) { + slog.Debug("OAuth log", "log", msg) +} + +func init() { + eduoauth.UpdateLogger(&OAuthLogger{}) +} diff --git a/internal/eduvpnapi/eduvpnapi_test.go b/internal/eduvpnapi/eduvpnapi_test.go new file mode 100644 index 0000000..23f895b --- /dev/null +++ b/internal/eduvpnapi/eduvpnapi_test.go @@ -0,0 +1,513 @@ +package eduvpnapi + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "slices" + "strings" + "testing" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/test" + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" + "codeberg.org/jwijenbergh/eduoauth-go/v2" +) + +func tokenHandler(t *testing.T, gt []string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for token handler: %v", r.Method) + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + grant := parsed.Get("grant_type") + + if slices.Contains(gt, grant) { + _, err = w.Write([]byte(` +{ + "access_token": "validaccess", + "refresh_token": "validrefresh", + "expires_in": 3600 +} + `)) + if err != nil { + t.Fatalf("failed writing in token handler: %v", err) + } + return + } + t.Fatalf("grant type: %v, not allowed", grant) + } +} + +func checkAuthBearer(t *testing.T, r *http.Request) { + authh := r.Header.Get("Authorization") + if !strings.HasPrefix(authh, "Bearer ") { + t.Fatalf("API call is not given with an authorization Bearer header, got: %v", authh) + } +} + +func connectHandler(t *testing.T, proto string, exp time.Time) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for connect handler: %v", r.Method) + } + checkAuthBearer(t, r) + w.Header().Set("expires", exp.Format(http.TimeFormat)) + w.Header().Set("content-type", fmt.Sprintf("application/x-%s-profile", proto)) + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + // the wireguard config we parse + var cfg string + if proto == "openvpn" { + cfg = "openvpnconfig" + } else { + if parsed.Get("public_key") == "" { + t.Fatalf("no public_key given") + } + if proto == "wireguard+tcp" { + ptcp := parsed.Get("prefer_tcp") + if ptcp != "yes" { + t.Fatalf("prefer TCP is not yes: %s", ptcp) + } + cfg = ` +[Interface] +[Peer] +ProxyEndpoint = https://proxyendpoint +` + } else { + cfg = "[Interface]" + } + } + _, err = w.Write([]byte(cfg)) + if err != nil { + t.Fatalf("failed writing /connect response: %v", err) + } + } +} + +func disconnectHandler(t *testing.T) func(http.ResponseWriter, *http.Request) { + return func(_ http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for disconnect handler: %v", r.Method) + } + checkAuthBearer(t, r) + } +} + +type TestCallback struct { + t *testing.T +} + +func (tc *TestCallback) TriggerAuth(_ context.Context, str string, _ bool) (string, error) { + go func() { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + ru, err := url.Parse(u.Query().Get("redirect_uri")) + if err != nil { + panic(err) + } + oq := u.Query() + q := ru.Query() + q.Set("state", oq.Get("state")) + q.Set("code", "fakeauthcode") + ru.RawQuery = q.Encode() + + c := http.Client{} + req, err := http.NewRequest("GET", ru.String(), nil) + if err != nil { + panic(err) + } + _, err = c.Do(req) + if err != nil { + panic(err) + } + }() + return "", nil +} +func (tc *TestCallback) AuthDone(string, server.Type) {} +func (tc *TestCallback) TokensUpdated(string, server.Type, eduoauth.Token) {} + +// create a API struct with allowed grant types +func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.HandlerPath) (*API, *test.Server) { + // Create a simple API client and check if the fields are created correctly + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to setup listener for test server: %v", err) + } + + hps = append(hps, []test.HandlerPath{ + { + Method: http.MethodGet, + Path: "/.well-known/vpn-user-portal", + Response: fmt.Sprintf(` +{ + "api": { + "http://eduvpn.org/api#3": { + "api_endpoint": "https://%[1]s/test-api-endpoint", + "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", + "token_endpoint": "https://%[1]s/test-token-endpoint" + } + }, + "v": "0.0.0" +} +`, listen.Addr().String()), + }, + { + Path: "/test-token-endpoint", + ResponseHandler: tokenHandler(t, gt), + }, + }...) + // start server + serv := test.NewServerWithHandles(hps, listen) + servc, err := serv.Client() + if err != nil { + t.Fatalf("failed to setup HTTP test server client: %v", servc) + } + + sd := ServerData{ + ID: "randomidentifier", + Type: server.TypeCustom, + BaseWK: serv.URL, + BaseAuthWK: serv.URL, + ProcessAuth: func(_ context.Context, in string) (string, error) { + return in, nil + }, + DisableAuthorize: false, + transport: servc.Client.Transport, + } + + tc := &TestCallback{t: t} + + a, err := NewAPI(context.Background(), "testclient", sd, tc, tok) + if err != nil { + t.Fatalf("failed creating API: %v", err) + } + return a, serv +} + +func TestNewAPI(t *testing.T) { + gts := []string{"refresh_token"} + tok := &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "expiredrefresh", + // tokens are expired, let's try authorizing + ExpiredTimestamp: time.Now(), + } + a, srv := createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } + + gts = []string{"authorization_code"} + tok = &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "", + ExpiredTimestamp: time.Now(), + } + a, srv = createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } +} + +func TestAPIInfo(t *testing.T) { + // auth should not be triggered + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + statErr := &httpwrap.StatusError{} + cases := []struct { + hp test.HandlerPath + info *profiles.Info + err any + }{ + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "default_gateway": false, + "display_name": "test profile 1", + "profile_id": "test1", + "profile_priority": 3, + "vpn_proto_list": [ + "openvpn", + "wireguard" + ] + } + ] + } +} +`, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test1", + DisplayName: "test profile 1", + VPNProtoList: []string{"openvpn", "wireguard"}, + Priority: 3, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "display_name": "test profile 2", + "profile_id": "test2", + "vpn_proto_list": [ + "wireguard" + ] + } + ] + } +} +`, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test2", + DisplayName: "test profile 2", + VPNProtoList: []string{"wireguard"}, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: "", + ResponseCode: 404, + }, + info: nil, + err: &statErr, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gprfs, err := a.Info(context.Background()) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed profiles info: %v but want no error", err) + } + + if !errors.As(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err.Error()) + } + } else if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + + if !reflect.DeepEqual(gprfs, c.info) { + t.Fatalf("got info: %v, not equal to want: %v", gprfs, c.info) + } + } +} + +func TestAPIConnect(t *testing.T) { + // auth should not be triggered + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + cases := []struct { + hp test.HandlerPath + cd *ConnectData + prof profiles.Profile + protos []protocol.Protocol + ptcp bool + err error + }{ + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + err: ErrNoProtocols, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + protos: []protocol.Protocol{protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "openvpn", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: "openvpnconfig\nscript-security 0", + Protocol: protocol.OpenVPN, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuard, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard+tcp", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuardProxy, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + ptcp: true, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gcd, err := a.Connect(context.Background(), c.prof, c.protos, c.ptcp) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed connect: %v but want no error", err) + } + + if !errors.Is(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err) + } + } else if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + + if gcd != nil && c.cd != nil { + m, err := regexp.MatchString(c.cd.Configuration, gcd.Configuration) + if err != nil { + t.Fatalf("failed matching regexp: %v", err) + } + if !m { + t.Fatalf("regex:\n%s\ndoes not match config:\n%s", c.cd.Configuration, gcd.Configuration) + } + // we have already checked the config using a regex + c.cd.Configuration = gcd.Configuration + + } + if !reflect.DeepEqual(gcd, c.cd) { + t.Fatalf("got connect data: %v, not equal to want: %v", gcd, c.cd) + } + } +} + +func TestDisconnect(t *testing.T) { + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{ + { + Path: "/test-api-endpoint/disconnect", + ResponseHandler: disconnectHandler(t), + }, + }) + defer srv.Close() + err := a.Disconnect(context.Background()) + if err != nil { + t.Fatalf("failed /disconnect: %v", err) + } +} diff --git a/internal/eduvpnapi/endpoints/endpoints.go b/internal/eduvpnapi/endpoints/endpoints.go new file mode 100644 index 0000000..c98d2c7 --- /dev/null +++ b/internal/eduvpnapi/endpoints/endpoints.go @@ -0,0 +1,62 @@ +// Package endpoints defines a wrapper around the various +// endpoints returned by an eduVPN server in well-known +package endpoints + +import ( + "fmt" + "net/url" +) + +// List is the list of endpoints as returned by the eduVPN server +type List struct { + // API is the API endpoint which we use for calls such as /info, /connect, ... + API string `json:"api_endpoint"` + // Authorization is the authorization endpoint for OAuth + Authorization string `json:"authorization_endpoint"` + // Token is the token endpoint for OAuth + Token string `json:"token_endpoint"` +} + +// Versions is the endpoints separated by API version +type Versions struct { + // V2 is the legacy V2 API, this is not used + V2 List `json:"http://eduvpn.org/api#2"` + // V3 is the newest API, which we use + V3 List `json:"http://eduvpn.org/api#3"` +} + +// Endpoints defines the json format for /.well-known/vpn-user-portal". +type Endpoints struct { + // API defines the API endpoints, split by version + API Versions `json:"api"` + // V is the version string for the server + V string `json:"v"` +} + +// Validate validates the endpoints by parsing them and checking the scheme is HTTP +// An error is returned if they are not valid +func (e Endpoints) Validate() error { + v3 := e.API.V3 + pAPI, err := url.Parse(v3.API) + if err != nil { + return fmt.Errorf("failed to parse API endpoint: %w", err) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return fmt.Errorf("failed to parse API authorization endpoint: %w", err) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return fmt.Errorf("failed to parse API token endpoint: %w", err) + } + if pAPI.Scheme != "https" { + return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme) + } + if pAPI.Scheme != pAuth.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + return nil +} diff --git a/internal/eduvpnapi/profiles/profiles.go b/internal/eduvpnapi/profiles/profiles.go new file mode 100644 index 0000000..77109f1 --- /dev/null +++ b/internal/eduvpnapi/profiles/profiles.go @@ -0,0 +1,119 @@ +// Package profiles defines a wrapper around the various profiles +// returned by the /info endpoint +package profiles + +import ( + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" +) + +// Profile is the information for a profile +type Profile struct { + // ID is the identifier of the profile + // Used to select a profile + ID string `json:"profile_id"` + // DisplayName defines the UI friendly name for the profile + DisplayName string `json:"display_name"` + // VPNProtoList defines the list of VPN protocols + // E.g. wireguard, openvpn + VPNProtoList []string `json:"vpn_proto_list"` + // VPNProtoTransportList defines the list of VPN protocols including their transport values + // E.g. wireguard+udp, openvpn+tcp + VPNProtoTransportList []string `json:"vpn_proto_transport_list"` + // DefaultGateway specifies whether or not this profile is a default gateway profile + DefaultGateway bool `json:"default_gateway"` + // DNSSearchDomains specifies the list of dns search domains + // This is provided for a Linux client issue + // See: https://github.com/eduvpn/python-eduvpn-client/issues/550 + DNSSearchDomains []string `json:"dns_search_domain_list"` + // Priority is the priority of the profile for sorting in the UI + // the higher the priority, the higher it should be in the list + Priority int `json:"profile_priority"` +} + +// ListInfo is the struct that has the profile list +type ListInfo struct { + ProfileList []Profile `json:"profile_list"` +} + +// Info is the top-level struct for the info endpoint +type Info struct { + Info ListInfo `json:"info"` +} + +// Len returns the length of the profile list +func (i Info) Len() int { + return len(i.Info.ProfileList) +} + +// Get returns a profile with id `id`, it returns nil if it is not found +func (i Info) Get(id string) *Profile { + for _, p := range i.Info.ProfileList { + if p.ID == id { + return &p + } + } + return nil +} + +// MustIndex gets a profile by index +// This index must be in the bounds +func (i Info) MustIndex(n int) Profile { + return i.Info.ProfileList[n] +} + +func hasProtocol(protos []string, proto protocol.Protocol) bool { + for _, p := range protos { + if protocol.New(p) == proto { + return true + } + } + return false +} + +// ShouldFailover returns whether or not this VPN profile should start a failover procedure +// This is true when the profile supports a TCP connection +// If we cannot determine whether it supports a TCP connection +// (because the server doesn't provide the VPN transport list function yet), +// we will just check if it supports OpenVPN +func (p *Profile) ShouldFailover() bool { + // old servers don't support it, only failover in case OpenVPN is supported + if len(p.VPNProtoTransportList) == 0 { + // this checks VPNProtoList + return p.HasOpenVPN() + } + for _, c := range p.VPNProtoTransportList { + if c == "wireguard+tcp" { + return true + } + if c == "openvpn+tcp" { + return true + } + } + return false +} + +// HasOpenVPN returns whether or not the profile has OpenVPN support +func (p *Profile) HasOpenVPN() bool { + return hasProtocol(p.VPNProtoList, protocol.OpenVPN) +} + +// HasWireGuard returns whether or not the profile has WireGuard support +func (p *Profile) HasWireGuard() bool { + return hasProtocol(p.VPNProtoList, protocol.WireGuard) +} + +// Public gets the server list as a structure that we return to clients +func (i Info) Public() server.Profiles { + m := make(map[string]server.Profile) + for _, p := range i.Info.ProfileList { + m[p.ID] = server.Profile{ + DisplayName: map[string]string{ + "en": p.DisplayName, + }, + DefaultGateway: p.DefaultGateway, + Priority: p.Priority, + } + } + return server.Profiles{Map: m} +} diff --git a/internal/eduvpnapi/redirect.go b/internal/eduvpnapi/redirect.go new file mode 100644 index 0000000..7af31fb --- /dev/null +++ b/internal/eduvpnapi/redirect.go @@ -0,0 +1,28 @@ +package eduvpnapi + +// customRedirects supplies redirect values that should be handled by the app itself +// here we hardcode the redirect values that we should use in the OAuth requests +// these values were taken from https://codeberg.org/eduVPN/vpn-user-portal/src/branch/v3/src/OAuth/VpnClientDb.php +var customRedirects = map[string]string{ + "org.letsconnect-vpn.app.macos": "org.letsconnect-vpn.app.macos:/api/callback", + "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback", + "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app.android:/api/callback", + "org.eduvpn.app.macos": "org.eduvpn.app.macos:/api/callback", + "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback", + "org.eduvpn.app.android": "org.eduvpn.app.android:/api/callback", + "org.govvpn.app.macos": "org.govvpn.app.macos:/api/callback", + "org.govvpn.app.ios": "org.govvpn.app.ios:/api/callback", + "org.govvpn.app.android": "org.govvpn.app.android:/api/callback", +} + +// customRedirect returns the custom redirect string for the clientID `cid` +// Empty string if none is defined or one is defined but is empty. +// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects +// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests +func customRedirect(cid string) string { + v, ok := customRedirects[cid] + if !ok { + return "" + } + return v +} -- cgit v1.2.3