From a30ef6b27e578a4cf0a674b24f5b52b4c1516c63 Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Thu, 12 Feb 2026 12:34:08 +0100 Subject: All: Rename packages that sound useless or clash with std --- Makefile | 2 +- client/client.go | 16 +- client/client_test.go | 4 +- cmd/eduvpn-cli/main.go | 4 +- exports/exports_test_wrapper.go | 10 +- i18n/err/i18nerr.go | 4 +- internal/api/api.go | 395 ----------------------- internal/api/api_test.go | 513 ------------------------------ internal/api/cache.go | 67 ---- internal/api/endpoints/endpoints.go | 62 ---- internal/api/profiles/profiles.go | 119 ------- internal/api/redirect.go | 28 -- internal/commonver/commonver.go | 9 + internal/discovery/discovery.go | 18 +- internal/eduvpnapi/cache.go | 67 ++++ internal/eduvpnapi/eduvpnapi.go | 395 +++++++++++++++++++++++ internal/eduvpnapi/eduvpnapi_test.go | 513 ++++++++++++++++++++++++++++++ internal/eduvpnapi/endpoints/endpoints.go | 62 ++++ internal/eduvpnapi/profiles/profiles.go | 119 +++++++ internal/eduvpnapi/redirect.go | 28 ++ internal/http/http.go | 288 ----------------- internal/http/http_test.go | 61 ---- internal/httpwrap/httpwrap.go | 288 +++++++++++++++++ internal/httpwrap/httpwrap_test.go | 61 ++++ internal/log/log.go | 37 --- internal/log/rotate.go | 109 ------- internal/log/rotate_test.go | 144 --------- internal/loglevel/loglevel.go | 37 +++ internal/loglevel/rotate.go | 109 +++++++ internal/loglevel/rotate_test.go | 144 +++++++++ internal/server/custom.go | 10 +- internal/server/institute.go | 10 +- internal/server/secureinternet.go | 10 +- internal/server/server.go | 10 +- internal/server/servers.go | 4 +- internal/test/server.go | 6 +- internal/version/version.go | 9 - make_release.sh | 4 +- prepare_release.sh | 4 +- proxy/proxy.go | 4 +- upload_release.sh | 2 +- wrappers/python/Makefile | 2 +- 42 files changed, 1894 insertions(+), 1894 deletions(-) delete mode 100644 internal/api/api.go delete mode 100644 internal/api/api_test.go delete mode 100644 internal/api/cache.go delete mode 100644 internal/api/endpoints/endpoints.go delete mode 100644 internal/api/profiles/profiles.go delete mode 100644 internal/api/redirect.go create mode 100644 internal/commonver/commonver.go create mode 100644 internal/eduvpnapi/cache.go create mode 100644 internal/eduvpnapi/eduvpnapi.go create mode 100644 internal/eduvpnapi/eduvpnapi_test.go create mode 100644 internal/eduvpnapi/endpoints/endpoints.go create mode 100644 internal/eduvpnapi/profiles/profiles.go create mode 100644 internal/eduvpnapi/redirect.go delete mode 100644 internal/http/http.go delete mode 100644 internal/http/http_test.go create mode 100644 internal/httpwrap/httpwrap.go create mode 100644 internal/httpwrap/httpwrap_test.go delete mode 100644 internal/log/log.go delete mode 100644 internal/log/rotate.go delete mode 100644 internal/log/rotate_test.go create mode 100644 internal/loglevel/loglevel.go create mode 100644 internal/loglevel/rotate.go create mode 100644 internal/loglevel/rotate_test.go delete mode 100644 internal/version/version.go diff --git a/Makefile b/Makefile index 7b94913..494ebd9 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ .PHONY: build docs fmt lint cli test clean coverage sloc -VERSION := $(shell grep -o 'const Version = "[^"]*' internal/version/version.go | cut -d '"' -f 2) +VERSION := $(shell grep -o 'const Version = "[^"]*' internal/commonver/commonver.go | cut -d '"' -f 2) build: diff --git a/client/client.go b/client/client.go index ae468ed..ed1f673 100644 --- a/client/client.go +++ b/client/client.go @@ -12,13 +12,13 @@ import ( "time" "codeberg.org/eduVPN/eduvpn-common/i18n/err" - "codeberg.org/eduVPN/eduvpn-common/internal/api" "codeberg.org/eduVPN/eduvpn-common/internal/config" "codeberg.org/eduVPN/eduvpn-common/internal/discovery" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi" "codeberg.org/eduVPN/eduvpn-common/internal/failover" "codeberg.org/eduVPN/eduvpn-common/internal/fsm" - "codeberg.org/eduVPN/eduvpn-common/internal/http" - "codeberg.org/eduVPN/eduvpn-common/internal/log" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/loglevel" "codeberg.org/eduVPN/eduvpn-common/internal/server" "codeberg.org/eduVPN/eduvpn-common/types/cookie" srvtypes "codeberg.org/eduVPN/eduvpn-common/types/server" @@ -165,7 +165,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // Initialize provided logger or use default if logger == nil { - logger = &log.Logger{} + logger = &loglevel.Logger{} } c.logr = logger slogger, err := c.logr.Init(directory) @@ -178,7 +178,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt c.Name = name // register HTTP agent - http.RegisterAgent(userAgentName(name), version) + httpwrap.RegisterAgent(userAgentName(name), version) // Initialize the FSM c.FSM = newFSM(stateCallback) @@ -398,7 +398,7 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } if _type != srvtypes.TypeSecureInternet { // Convert to an identifier - identifier, err = http.EnsureValidURL(identifier, true) + identifier, err = httpwrap.EnsureValidURL(identifier, true) if err != nil { return i18nerr.WrapInternalf(err, "failed to convert identifier: %v", identifier) } @@ -432,7 +432,7 @@ func (c *Client) convertIdentifier(identifier string, t srvtypes.Type) (string, return identifier, nil } // Convert to an identifier, this also converts the scheme to HTTPS - identifier, err := http.EnsureValidURL(identifier, true) + identifier, err := httpwrap.EnsureValidURL(identifier, true) if err != nil { return "", i18nerr.Wrapf(err, "The input: '%s' is not a valid URL", identifier) } @@ -511,7 +511,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. } if err != nil { if startup { - if errors.Is(err, api.ErrAuthorizeDisabled) { + if errors.Is(err, eduvpnapi.ErrAuthorizeDisabled) { return nil, i18nerr.Newf("The client tried to autoconnect to the VPN server: '%s', but you need to authorize again. Please manually connect again.", identifier) } return nil, i18nerr.Wrapf(err, "The client tried to autoconnect to the VPN server: '%s', but the operation failed to complete", identifier) diff --git a/client/client_test.go b/client/client_test.go index dd69688..6d34da3 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" "codeberg.org/eduVPN/eduvpn-common/internal/test" "codeberg.org/eduVPN/eduvpn-common/types/cookie" "codeberg.org/eduVPN/eduvpn-common/types/protocol" @@ -98,7 +98,7 @@ func getServerURI(t *testing.T) string { if serverURI == "" { t.Skip("Skipping server test as no SERVER_URI env var has been passed") } - serverURI, parseErr := httpw.EnsureValidURL(serverURI, true) + serverURI, parseErr := httpwrap.EnsureValidURL(serverURI, true) if parseErr != nil { t.Skip("Skipping server test as the server uri is not valid") } diff --git a/cmd/eduvpn-cli/main.go b/cmd/eduvpn-cli/main.go index bb90e40..db255cd 100644 --- a/cmd/eduvpn-cli/main.go +++ b/cmd/eduvpn-cli/main.go @@ -11,7 +11,7 @@ import ( "codeberg.org/eduVPN/eduvpn-common/client" "codeberg.org/eduVPN/eduvpn-common/i18n" - "codeberg.org/eduVPN/eduvpn-common/internal/version" + "codeberg.org/eduVPN/eduvpn-common/internal/commonver" "codeberg.org/eduVPN/eduvpn-common/types/cookie" srvtypes "codeberg.org/eduVPN/eduvpn-common/types/server" @@ -148,7 +148,7 @@ func printConfig(url string, cc string, srvType srvtypes.Type, prof string) erro defer os.RemoveAll(dir) //nolint:errcheck c, err = client.New( "org.eduvpn.app.linux", - fmt.Sprintf("%s-cli", version.Version), + fmt.Sprintf("%s-cli", commonver.Version), dir, func(oldState client.FSMStateID, newState client.FSMStateID, data any) bool { stateCallback(oldState, newState, data, prof, dir) diff --git a/exports/exports_test_wrapper.go b/exports/exports_test_wrapper.go index 013a7a3..eec5a12 100644 --- a/exports/exports_test_wrapper.go +++ b/exports/exports_test_wrapper.go @@ -25,7 +25,7 @@ import ( "codeberg.org/eduVPN/eduvpn-common/internal/test" "codeberg.org/eduVPN/eduvpn-common/types/error" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" ) func getString(in *C.char) string { @@ -148,7 +148,7 @@ func fakeBrowserAuth(str string) (string, error) { } func testServer(t *testing.T) *test.Server { - // TODO: duplicate code between this and internal/api/api_test.go + // TODO: duplicate code between this and internal/eduvpnapi/eduvpnapi_test.go listen, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to setup listener for test server: %v", err) @@ -266,7 +266,7 @@ func testServerList(t *testing.T) { t.Fatalf("failed to obtain server client: %v", err) } - httpw.DefaultTransport = sclient.Client.Transport.(*http.Transport) + httpwrap.DefaultTransport = sclient.Client.Transport.(*http.Transport) gerr := getError(t, AddServer(ck, 3, listS, nil)) if gerr != "" { @@ -418,7 +418,7 @@ func testGetConfig(t *testing.T) { t.Fatalf("failed to obtain server client: %v", err) } - httpw.DefaultTransport = sclient.Client.Transport.(*http.Transport) + httpwrap.DefaultTransport = sclient.Client.Transport.(*http.Transport) _, cfgErr := GetConfig(ck, 3, listS, 0, 0) cfgErrS := getError(t, cfgErr) @@ -501,7 +501,7 @@ func testLetsConnectDiscovery(t *testing.T) { t.Fatalf("failed to obtain server client: %v", err) } - httpw.DefaultTransport = sclient.Client.Transport.(*http.Transport) + httpwrap.DefaultTransport = sclient.Client.Transport.(*http.Transport) // try to add an institute access server exptErr := fmt.Sprintf("An internal error occurred. The cause of the error is: Adding a non-custom server when the client does not use discovery is not supported, identifier: %s, type: 1.", list) diff --git a/i18n/err/i18nerr.go b/i18n/err/i18nerr.go index 8254dd4..1c97462 100644 --- a/i18n/err/i18nerr.go +++ b/i18n/err/i18nerr.go @@ -8,7 +8,7 @@ import ( "log/slog" "sync" - "codeberg.org/eduVPN/eduvpn-common/internal/http" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" "golang.org/x/text/language" "golang.org/x/text/message" @@ -26,7 +26,7 @@ func TranslatedInner(inner error) (string, bool) { unwrapped = errors.Unwrap(unwrapped) } - var tErr *http.TimeoutError + var tErr *httpwrap.TimeoutError switch { case errors.As(inner, &tErr): return printerOrNew(language.English).Sprintf("Timeout reached contacting URL: '%s'", tErr.URL), false diff --git a/internal/api/api.go b/internal/api/api.go deleted file mode 100644 index 0d8e03c..0000000 --- a/internal/api/api.go +++ /dev/null @@ -1,395 +0,0 @@ -// Package api implements version 3 of the eduVPN api: https://docs.eduvpn.org/server/v3/api.html -package api - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - "net/http" - "net/url" - "time" - - "codeberg.org/jwijenbergh/eduoauth-go/v2" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "codeberg.org/eduVPN/eduvpn-common/internal/api/endpoints" - "codeberg.org/eduVPN/eduvpn-common/internal/api/profiles" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" - "codeberg.org/eduVPN/eduvpn-common/internal/wireguard" - "codeberg.org/eduVPN/eduvpn-common/types/protocol" - "codeberg.org/eduVPN/eduvpn-common/types/server" -) - -// Callbacks is the API callback interface -// It is used to trigger authorization and forward token updates -type Callbacks interface { - // TriggerAuth is called when authorization should be triggered - TriggerAuth(context.Context, string, bool) (string, error) - // AuthDone is called when authorization has just completed - AuthDone(string, server.Type) - // TokensUpdates is called when tokens are updated - TokensUpdated(string, server.Type, eduoauth.Token) -} - -// ServerData is the data for a server that is passed to the API struct -type ServerData struct { - // ID is the identifier for the server - ID string - // Type is the type of server - Type server.Type - // BaseWK is the base well-known endpoint - BaseWK string - // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet - BaseAuthWK string - // ProcessAuth processes the OAuth authorization - ProcessAuth func(context.Context, string) (string, error) - // DisableAuthorize indicates whether or not new authorization requests should be disabled - DisableAuthorize bool - // transport is the HTTP transport, only used for testing currently - transport http.RoundTripper -} - -// Transport returns the transport to be used for the server -// By default it uses the transport from internal/http DefaultTransport -func (s *ServerData) Transport() http.RoundTripper { - if s.transport == nil { - return httpw.DefaultTransport - } - return s.transport -} - -// API is the top-level struct that each method is defined on -type API struct { - cb Callbacks - // oauth is the oauth object - oauth *eduoauth.OAuth - // Data is the server data - Data ServerData -} - -// NewAPI creates a new API object by creating an OAuth object -func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) { - cr := customRedirect(clientID) - // Construct OAuth - - transp := sd.Transport() - post := true - // we do not support non-loopback clients with response_mode form_post - if cr != "" { - post = false - } - o := eduoauth.OAuth{ - ClientID: clientID, - EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { - ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) - if err != nil { - return nil, err - } - return &eduoauth.EndpointResponse{ - AuthorizationURL: ep.API.V3.Authorization, - TokenURL: ep.API.V3.Token, - }, nil - }, - CustomRedirect: cr, - FormPost: post, - RedirectPath: "/callback", - TokensUpdated: func(tok eduoauth.Token) { - cb.TokensUpdated(sd.ID, sd.Type, tok) - }, - Transport: transp, - UserAgent: httpw.UserAgent, - } - - if tokens != nil { - o.UpdateTokens(*tokens) - } - - api := &API{ - cb: cb, - oauth: &o, - Data: sd, - } - err := api.authorize(ctx) - if err != nil { - return nil, err - } - return api, nil -} - -// ErrAuthorizeDisabled is returned when authorization is disabled but is needed to complete -var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled") - -func (a *API) authorize(ctx context.Context) (err error) { - _, err = a.oauth.AccessToken(ctx) - // already authorized - if err == nil { - return nil - } - - // otherwise check if invalid tokens, - // if not then something else is wrong with the API - // return an error - tErr := &eduoauth.TokensInvalidError{} - if !errors.As(err, &tErr) { - return err - } - - if a.Data.DisableAuthorize { - return ErrAuthorizeDisabled - } - - defer func() { - if err == nil { - a.cb.AuthDone(a.Data.ID, a.Data.Type) - } - }() - - scope := "config" - url, err := a.oauth.AuthURL(ctx, scope) - if err != nil { - return err - } - if a.Data.ProcessAuth != nil { - url, err = a.Data.ProcessAuth(ctx, url) - if err != nil { - return err - } - } - // We expect an uri if custom redirect is non empty - uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") - if err != nil { - return err - } - // The uri is only given here if a custom redirect is done - err = a.oauth.Exchange(ctx, uri) - if err != nil { - return err - } - return nil -} - -func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) - if err != nil { - return nil, nil, err - } - u := ep.API.V3.API + endpoint - - // TODO: Cache HTTP client? - httpC := httpw.NewClient(a.oauth.NewHTTPClient()) - return httpC.Do(ctx, method, u, opts) -} - -func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - h, body, err := a.authorized(ctx, method, endpoint, opts) - if err == nil { - return h, body, nil - } - - statErr := &httpw.StatusError{} - // Only retry authorized if we get an HTTP 401 - // TODO: Can the OAuth client handle this instead? - if errors.As(err, &statErr) && statErr.Status == 401 { - slog.Debug("Got a HTTP 401. Marking tokens as expired...", "HTTP method", method, "endpoint", endpoint) - // Mark the token as expired and retry, so we trigger the refresh flow - a.oauth.SetTokenExpired() - h, body, err = a.authorized(ctx, method, endpoint, opts) - } - // Tokens is invalid we need to renew and authorize again - tErr := &eduoauth.TokensInvalidError{} - if err != nil && errors.As(err, &tErr) { - // Mark the token as invalid and retry, so we trigger the authorization flow - a.oauth.SetTokenRenew() - slog.Debug("The tokens were invalid, trying again...") - if autherr := a.authorize(ctx); autherr != nil { - return nil, nil, autherr - } - return a.authorized(ctx, method, endpoint, opts) - } - return h, body, err -} - -// Disconnect disconnects a client from the server by sending a /disconnect API call -// This cleans up resources such as WireGuard IP allocation -func (a *API) Disconnect(ctx context.Context) error { - _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) - return err -} - -// Info does the /info API call -func (a *API) Info(ctx context.Context) (*profiles.Info, error) { - _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) - if err != nil { - return nil, fmt.Errorf("failed API /info: %w", err) - } - p := profiles.Info{} - if err = json.Unmarshal(body, &p); err != nil { - return nil, fmt.Errorf("failed API /info: %w", err) - } - return &p, nil -} - -// ConnectData is the data that is returned when the /connect call completes without error -type ConnectData struct { - // Configuration is the VPN configuration - Configuration string - // Protocol tells us what protocol it is, OpenVPN or WireGuard (proxied or not) - Protocol protocol.Protocol - // Expires tells us when this configuration expires - Expires time.Time -} - -// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 -func boolToYesNo(preferTCP bool) string { - if preferTCP { - return "yes" - } - return "no" -} - -func protocolFromCT(ct string) (protocol.Protocol, error) { - switch ct { - case "application/x-wireguard-profile": - return protocol.WireGuard, nil - case "application/x-wireguard+tcp-profile": - return protocol.WireGuardProxy, nil - case "application/x-openvpn-profile": - return protocol.OpenVPN, nil - } - return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct) -} - -// ErrNoProtocols is returned when a connect call is given with an empty protocol slice -var ErrNoProtocols = errors.New("no protocols supplied") - -// ErrUnknownProtocol is returned when the client in a connect gives an unknown protocol -var ErrUnknownProtocol = errors.New("unknown protocol supplied") - -// Connect sends a /connect to an eduVPN server -// `ctx` is the context used for cancellation -// protos is the list of protocols supported and wanted by the client -func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) { - hdrs := http.Header{ - "content-type": {"application/x-www-form-urlencoded"}, - } - uv := url.Values{ - "profile_id": {prof.ID}, - } - - if len(protos) == 0 { - return nil, ErrNoProtocols - } - - var wgKey *wgtypes.Key - - // Loop over the protocols and set the correct headers and values - for _, p := range protos { - switch p { - case protocol.WireGuard: - gk, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, err - } - wgKey = &gk - // Set the public key - pubkey := wgKey.PublicKey() - uv.Set("public_key", pubkey.String()) - hdrs.Add("accept", "application/x-wireguard-profile") - hdrs.Add("accept", "application/x-wireguard+tcp-profile") - case protocol.OpenVPN: - hdrs.Add("accept", "application/x-openvpn-profile") - default: - return nil, ErrUnknownProtocol - } - } - // set prefer TCP - uv.Set("prefer_tcp", boolToYesNo(pTCP)) - - // Construct the parameters - params := &httpw.OptionalParams{Headers: hdrs, Body: uv} - h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) - if err != nil { - return nil, fmt.Errorf("failed API /connect call: %w", err) - } - - // Parse expiry - expH := h.Get("expires") - if expH == "" { - return nil, errors.New("the server did not give an expires header") - } - expT, err := http.ParseTime(expH) - if err != nil { - return nil, fmt.Errorf("failed parsing expiry time: %w", err) - } - - vpnCfg := string(body) - // Parse content type - contentH := h.Get("content-type") - proto, err := protocolFromCT(contentH) - if err != nil { - return nil, err - } - - if proto == protocol.OpenVPN { - // ensure scripts are not ran by default by append script-security 0 to the config - vpnCfg += "\nscript-security 0" - return &ConnectData{ - Configuration: vpnCfg, - Protocol: proto, - Expires: expT, - }, nil - } - - vpnCfg, err = wireguard.Config(vpnCfg, wgKey) - if err != nil { - return nil, err - } - return &ConnectData{ - Configuration: vpnCfg, - Protocol: proto, - Expires: expT, - }, nil -} - -func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) { - uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal") - if err != nil { - return nil, err - } - httpC := httpw.NewClient(nil) - httpC.Client.Transport = tp - _, body, err := httpC.Get(ctx, uStr) - if err != nil { - return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) - } - - ep := endpoints.Endpoints{} - if err = json.Unmarshal(body, &ep); err != nil { - return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) - } - err = ep.Validate() - if err != nil { - return nil, err - } - return &ep, nil -} - -// OAuthLogger is defined here to update the internal logger -// for the eduoauth library -type OAuthLogger struct{} - -// Logf logs a message with parameters -func (ol *OAuthLogger) Logf(msg string, params ...any) { - slog.Debug("OAuth log", "log", fmt.Sprintf(msg, params...)) -} - -// Log logs a message -func (ol *OAuthLogger) Log(msg string) { - slog.Debug("OAuth log", "log", msg) -} - -func init() { - eduoauth.UpdateLogger(&OAuthLogger{}) -} diff --git a/internal/api/api_test.go b/internal/api/api_test.go deleted file mode 100644 index e88e816..0000000 --- a/internal/api/api_test.go +++ /dev/null @@ -1,513 +0,0 @@ -package api - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/url" - "reflect" - "regexp" - "slices" - "strings" - "testing" - "time" - - "codeberg.org/eduVPN/eduvpn-common/internal/api/profiles" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" - "codeberg.org/eduVPN/eduvpn-common/internal/test" - "codeberg.org/eduVPN/eduvpn-common/types/protocol" - "codeberg.org/eduVPN/eduvpn-common/types/server" - "codeberg.org/jwijenbergh/eduoauth-go/v2" -) - -func tokenHandler(t *testing.T, gt []string) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Fatalf("invalid HTTP method for token handler: %v", r.Method) - } - b, err := io.ReadAll(r.Body) - if err != nil { - t.Fatalf("failed reading token endpoint body: %v", err) - } - parsed, err := url.ParseQuery(string(b)) - if err != nil { - t.Fatalf("failed parsing query body: %v", err) - } - grant := parsed.Get("grant_type") - - if slices.Contains(gt, grant) { - _, err = w.Write([]byte(` -{ - "access_token": "validaccess", - "refresh_token": "validrefresh", - "expires_in": 3600 -} - `)) - if err != nil { - t.Fatalf("failed writing in token handler: %v", err) - } - return - } - t.Fatalf("grant type: %v, not allowed", grant) - } -} - -func checkAuthBearer(t *testing.T, r *http.Request) { - authh := r.Header.Get("Authorization") - if !strings.HasPrefix(authh, "Bearer ") { - t.Fatalf("API call is not given with an authorization Bearer header, got: %v", authh) - } -} - -func connectHandler(t *testing.T, proto string, exp time.Time) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Fatalf("invalid HTTP method for connect handler: %v", r.Method) - } - checkAuthBearer(t, r) - w.Header().Set("expires", exp.Format(http.TimeFormat)) - w.Header().Set("content-type", fmt.Sprintf("application/x-%s-profile", proto)) - b, err := io.ReadAll(r.Body) - if err != nil { - t.Fatalf("failed reading token endpoint body: %v", err) - } - parsed, err := url.ParseQuery(string(b)) - if err != nil { - t.Fatalf("failed parsing query body: %v", err) - } - // the wireguard config we parse - var cfg string - if proto == "openvpn" { - cfg = "openvpnconfig" - } else { - if parsed.Get("public_key") == "" { - t.Fatalf("no public_key given") - } - if proto == "wireguard+tcp" { - ptcp := parsed.Get("prefer_tcp") - if ptcp != "yes" { - t.Fatalf("prefer TCP is not yes: %s", ptcp) - } - cfg = ` -[Interface] -[Peer] -ProxyEndpoint = https://proxyendpoint -` - } else { - cfg = "[Interface]" - } - } - _, err = w.Write([]byte(cfg)) - if err != nil { - t.Fatalf("failed writing /connect response: %v", err) - } - } -} - -func disconnectHandler(t *testing.T) func(http.ResponseWriter, *http.Request) { - return func(_ http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Fatalf("invalid HTTP method for disconnect handler: %v", r.Method) - } - checkAuthBearer(t, r) - } -} - -type TestCallback struct { - t *testing.T -} - -func (tc *TestCallback) TriggerAuth(_ context.Context, str string, _ bool) (string, error) { - go func() { - u, err := url.Parse(str) - if err != nil { - panic(err) - } - ru, err := url.Parse(u.Query().Get("redirect_uri")) - if err != nil { - panic(err) - } - oq := u.Query() - q := ru.Query() - q.Set("state", oq.Get("state")) - q.Set("code", "fakeauthcode") - ru.RawQuery = q.Encode() - - c := http.Client{} - req, err := http.NewRequest("GET", ru.String(), nil) - if err != nil { - panic(err) - } - _, err = c.Do(req) - if err != nil { - panic(err) - } - }() - return "", nil -} -func (tc *TestCallback) AuthDone(string, server.Type) {} -func (tc *TestCallback) TokensUpdated(string, server.Type, eduoauth.Token) {} - -// create a API struct with allowed grant types -func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.HandlerPath) (*API, *test.Server) { - // Create a simple API client and check if the fields are created correctly - listen, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to setup listener for test server: %v", err) - } - - hps = append(hps, []test.HandlerPath{ - { - Method: http.MethodGet, - Path: "/.well-known/vpn-user-portal", - Response: fmt.Sprintf(` -{ - "api": { - "http://eduvpn.org/api#3": { - "api_endpoint": "https://%[1]s/test-api-endpoint", - "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", - "token_endpoint": "https://%[1]s/test-token-endpoint" - } - }, - "v": "0.0.0" -} -`, listen.Addr().String()), - }, - { - Path: "/test-token-endpoint", - ResponseHandler: tokenHandler(t, gt), - }, - }...) - // start server - serv := test.NewServerWithHandles(hps, listen) - servc, err := serv.Client() - if err != nil { - t.Fatalf("failed to setup HTTP test server client: %v", servc) - } - - sd := ServerData{ - ID: "randomidentifier", - Type: server.TypeCustom, - BaseWK: serv.URL, - BaseAuthWK: serv.URL, - ProcessAuth: func(_ context.Context, in string) (string, error) { - return in, nil - }, - DisableAuthorize: false, - transport: servc.Client.Transport, - } - - tc := &TestCallback{t: t} - - a, err := NewAPI(context.Background(), "testclient", sd, tc, tok) - if err != nil { - t.Fatalf("failed creating API: %v", err) - } - return a, serv -} - -func TestNewAPI(t *testing.T) { - gts := []string{"refresh_token"} - tok := &eduoauth.Token{ - Access: "expiredaccess", - Refresh: "expiredrefresh", - // tokens are expired, let's try authorizing - ExpiredTimestamp: time.Now(), - } - a, srv := createTestAPI(t, tok, gts, nil) - srv.Close() - - // now the tokens should be the new access tokens - if a.oauth.Token().Access != "validaccess" { - t.Fatalf("access token is not valid access") - } - if a.oauth.Token().Refresh != "validrefresh" { - t.Fatalf("refresh token is not valid refresh") - } - - gts = []string{"authorization_code"} - tok = &eduoauth.Token{ - Access: "expiredaccess", - Refresh: "", - ExpiredTimestamp: time.Now(), - } - a, srv = createTestAPI(t, tok, gts, nil) - srv.Close() - - // now the tokens should be the new access tokens - if a.oauth.Token().Access != "validaccess" { - t.Fatalf("access token is not valid access") - } - if a.oauth.Token().Refresh != "validrefresh" { - t.Fatalf("refresh token is not valid refresh") - } -} - -func TestAPIInfo(t *testing.T) { - // auth should not be triggered - var gts []string - tok := &eduoauth.Token{ - Access: "validaccess", - Refresh: "validrefresh", - ExpiredTimestamp: time.Now().Add(1 * time.Hour), - } - statErr := &httpw.StatusError{} - cases := []struct { - hp test.HandlerPath - info *profiles.Info - err any - }{ - { - hp: test.HandlerPath{ - Method: http.MethodGet, - Path: "/test-api-endpoint/info", - Response: ` -{ - "info": { - "profile_list": [ - { - "default_gateway": false, - "display_name": "test profile 1", - "profile_id": "test1", - "profile_priority": 3, - "vpn_proto_list": [ - "openvpn", - "wireguard" - ] - } - ] - } -} -`, - }, - info: &profiles.Info{ - Info: profiles.ListInfo{ - ProfileList: []profiles.Profile{ - { - ID: "test1", - DisplayName: "test profile 1", - VPNProtoList: []string{"openvpn", "wireguard"}, - Priority: 3, - DefaultGateway: false, - }, - }, - }, - }, - }, - { - hp: test.HandlerPath{ - Method: http.MethodGet, - Path: "/test-api-endpoint/info", - Response: ` -{ - "info": { - "profile_list": [ - { - "display_name": "test profile 2", - "profile_id": "test2", - "vpn_proto_list": [ - "wireguard" - ] - } - ] - } -} -`, - }, - info: &profiles.Info{ - Info: profiles.ListInfo{ - ProfileList: []profiles.Profile{ - { - ID: "test2", - DisplayName: "test profile 2", - VPNProtoList: []string{"wireguard"}, - DefaultGateway: false, - }, - }, - }, - }, - }, - { - hp: test.HandlerPath{ - Method: http.MethodGet, - Path: "/test-api-endpoint/info", - Response: "", - ResponseCode: 404, - }, - info: nil, - err: &statErr, - }, - } - - for _, c := range cases { - a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) - defer srv.Close() - gprfs, err := a.Info(context.Background()) - // got error but the want error is nil - if err != nil { - if c.err == nil { - t.Fatalf("failed profiles info: %v but want no error", err) - } - - if !errors.As(err, c.err) { - t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err.Error()) - } - } else if c.err != nil { - t.Fatalf("got no error but want error: %T", c.err) - } - - if !reflect.DeepEqual(gprfs, c.info) { - t.Fatalf("got info: %v, not equal to want: %v", gprfs, c.info) - } - } -} - -func TestAPIConnect(t *testing.T) { - // auth should not be triggered - var gts []string - tok := &eduoauth.Token{ - Access: "validaccess", - Refresh: "validrefresh", - ExpiredTimestamp: time.Now().Add(1 * time.Hour), - } - cases := []struct { - hp test.HandlerPath - cd *ConnectData - prof profiles.Profile - protos []protocol.Protocol - ptcp bool - err error - }{ - { - hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - Response: ``, - }, - cd: nil, - err: ErrNoProtocols, - }, - { - hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - Response: ``, - }, - cd: nil, - protos: []protocol.Protocol{protocol.Unknown}, - err: ErrUnknownProtocol, - }, - { - hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - Response: ``, - }, - cd: nil, - protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown}, - err: ErrUnknownProtocol, - }, - { - hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - ResponseHandler: connectHandler(t, "openvpn", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), - }, - cd: &ConnectData{ - Configuration: "openvpnconfig\nscript-security 0", - Protocol: protocol.OpenVPN, - Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), - }, - protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, - err: nil, - }, - { - hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - ResponseHandler: connectHandler(t, "wireguard", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), - }, - cd: &ConnectData{ - Configuration: `\[Interface\] -PrivateKey = .*`, - Protocol: protocol.WireGuard, - Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), - }, - protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, - err: nil, - }, - { - hp: test.HandlerPath{ - Method: http.MethodPost, - Path: "/test-api-endpoint/connect", - ResponseHandler: connectHandler(t, "wireguard+tcp", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), - }, - cd: &ConnectData{ - Configuration: `\[Interface\] -PrivateKey = .*`, - Protocol: protocol.WireGuardProxy, - Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), - }, - ptcp: true, - protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, - err: nil, - }, - } - - for _, c := range cases { - a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) - defer srv.Close() - gcd, err := a.Connect(context.Background(), c.prof, c.protos, c.ptcp) - // got error but the want error is nil - if err != nil { - if c.err == nil { - t.Fatalf("failed connect: %v but want no error", err) - } - - if !errors.Is(err, c.err) { - t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err) - } - } else if c.err != nil { - t.Fatalf("got no error but want error: %T", c.err) - } - - if gcd != nil && c.cd != nil { - m, err := regexp.MatchString(c.cd.Configuration, gcd.Configuration) - if err != nil { - t.Fatalf("failed matching regexp: %v", err) - } - if !m { - t.Fatalf("regex:\n%s\ndoes not match config:\n%s", c.cd.Configuration, gcd.Configuration) - } - // we have already checked the config using a regex - c.cd.Configuration = gcd.Configuration - - } - if !reflect.DeepEqual(gcd, c.cd) { - t.Fatalf("got connect data: %v, not equal to want: %v", gcd, c.cd) - } - } -} - -func TestDisconnect(t *testing.T) { - var gts []string - tok := &eduoauth.Token{ - Access: "validaccess", - Refresh: "validrefresh", - ExpiredTimestamp: time.Now().Add(1 * time.Hour), - } - a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{ - { - Path: "/test-api-endpoint/disconnect", - ResponseHandler: disconnectHandler(t), - }, - }) - defer srv.Close() - err := a.Disconnect(context.Background()) - if err != nil { - t.Fatalf("failed /disconnect: %v", err) - } -} diff --git a/internal/api/cache.go b/internal/api/cache.go deleted file mode 100644 index 5c682f4..0000000 --- a/internal/api/cache.go +++ /dev/null @@ -1,67 +0,0 @@ -package api - -import ( - "context" - "net/http" - "sync" - "time" - - "codeberg.org/eduVPN/eduvpn-common/internal/api/endpoints" -) - -// EndpointCache is a struct that caches well-known API endpoints -type EndpointCache struct { - lastUpdate map[string]time.Time - lastEP map[string]*endpoints.Endpoints - mu sync.Mutex -} - -// Get returns a cached or fresh endpoint cache copy -func (ec *EndpointCache) Get(ctx context.Context, wk string, transport http.RoundTripper) (*endpoints.Endpoints, error) { - ec.mu.Lock() - defer ec.mu.Unlock() - - // get the last update time - lu := time.Time{} - if v, ok := ec.lastUpdate[wk]; ok { - lu = v - } - - // if not 10 minutes have passed, return cached copy - if !lu.IsZero() && !time.Now().After(lu.Add(10*time.Minute)) { - v, ok := ec.lastEP[wk] - if ok { - return v, nil - } - } - - // get fresh API endpoints - ep, err := getEndpoints(ctx, wk, transport) - if err != nil { - return nil, err - } - - // update endpoints - ec.lastUpdate[wk] = time.Now() - ec.lastEP[wk] = ep - - return ep, nil -} - -var ( - epCache *EndpointCache - epCacheOnce sync.Once -) - -// GetEndpointCache returns the global singleton endpoint cache -// or creates one if it does not exist -func GetEndpointCache() *EndpointCache { - epCacheOnce.Do(func() { - epCache = &EndpointCache{ - lastUpdate: make(map[string]time.Time), - lastEP: make(map[string]*endpoints.Endpoints), - } - }) - - return epCache -} diff --git a/internal/api/endpoints/endpoints.go b/internal/api/endpoints/endpoints.go deleted file mode 100644 index c98d2c7..0000000 --- a/internal/api/endpoints/endpoints.go +++ /dev/null @@ -1,62 +0,0 @@ -// Package endpoints defines a wrapper around the various -// endpoints returned by an eduVPN server in well-known -package endpoints - -import ( - "fmt" - "net/url" -) - -// List is the list of endpoints as returned by the eduVPN server -type List struct { - // API is the API endpoint which we use for calls such as /info, /connect, ... - API string `json:"api_endpoint"` - // Authorization is the authorization endpoint for OAuth - Authorization string `json:"authorization_endpoint"` - // Token is the token endpoint for OAuth - Token string `json:"token_endpoint"` -} - -// Versions is the endpoints separated by API version -type Versions struct { - // V2 is the legacy V2 API, this is not used - V2 List `json:"http://eduvpn.org/api#2"` - // V3 is the newest API, which we use - V3 List `json:"http://eduvpn.org/api#3"` -} - -// Endpoints defines the json format for /.well-known/vpn-user-portal". -type Endpoints struct { - // API defines the API endpoints, split by version - API Versions `json:"api"` - // V is the version string for the server - V string `json:"v"` -} - -// Validate validates the endpoints by parsing them and checking the scheme is HTTP -// An error is returned if they are not valid -func (e Endpoints) Validate() error { - v3 := e.API.V3 - pAPI, err := url.Parse(v3.API) - if err != nil { - return fmt.Errorf("failed to parse API endpoint: %w", err) - } - pAuth, err := url.Parse(v3.Authorization) - if err != nil { - return fmt.Errorf("failed to parse API authorization endpoint: %w", err) - } - pToken, err := url.Parse(v3.Token) - if err != nil { - return fmt.Errorf("failed to parse API token endpoint: %w", err) - } - if pAPI.Scheme != "https" { - return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme) - } - if pAPI.Scheme != pAuth.Scheme { - return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) - } - if pAPI.Scheme != pToken.Scheme { - return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) - } - return nil -} diff --git a/internal/api/profiles/profiles.go b/internal/api/profiles/profiles.go deleted file mode 100644 index 77109f1..0000000 --- a/internal/api/profiles/profiles.go +++ /dev/null @@ -1,119 +0,0 @@ -// Package profiles defines a wrapper around the various profiles -// returned by the /info endpoint -package profiles - -import ( - "codeberg.org/eduVPN/eduvpn-common/types/protocol" - "codeberg.org/eduVPN/eduvpn-common/types/server" -) - -// Profile is the information for a profile -type Profile struct { - // ID is the identifier of the profile - // Used to select a profile - ID string `json:"profile_id"` - // DisplayName defines the UI friendly name for the profile - DisplayName string `json:"display_name"` - // VPNProtoList defines the list of VPN protocols - // E.g. wireguard, openvpn - VPNProtoList []string `json:"vpn_proto_list"` - // VPNProtoTransportList defines the list of VPN protocols including their transport values - // E.g. wireguard+udp, openvpn+tcp - VPNProtoTransportList []string `json:"vpn_proto_transport_list"` - // DefaultGateway specifies whether or not this profile is a default gateway profile - DefaultGateway bool `json:"default_gateway"` - // DNSSearchDomains specifies the list of dns search domains - // This is provided for a Linux client issue - // See: https://github.com/eduvpn/python-eduvpn-client/issues/550 - DNSSearchDomains []string `json:"dns_search_domain_list"` - // Priority is the priority of the profile for sorting in the UI - // the higher the priority, the higher it should be in the list - Priority int `json:"profile_priority"` -} - -// ListInfo is the struct that has the profile list -type ListInfo struct { - ProfileList []Profile `json:"profile_list"` -} - -// Info is the top-level struct for the info endpoint -type Info struct { - Info ListInfo `json:"info"` -} - -// Len returns the length of the profile list -func (i Info) Len() int { - return len(i.Info.ProfileList) -} - -// Get returns a profile with id `id`, it returns nil if it is not found -func (i Info) Get(id string) *Profile { - for _, p := range i.Info.ProfileList { - if p.ID == id { - return &p - } - } - return nil -} - -// MustIndex gets a profile by index -// This index must be in the bounds -func (i Info) MustIndex(n int) Profile { - return i.Info.ProfileList[n] -} - -func hasProtocol(protos []string, proto protocol.Protocol) bool { - for _, p := range protos { - if protocol.New(p) == proto { - return true - } - } - return false -} - -// ShouldFailover returns whether or not this VPN profile should start a failover procedure -// This is true when the profile supports a TCP connection -// If we cannot determine whether it supports a TCP connection -// (because the server doesn't provide the VPN transport list function yet), -// we will just check if it supports OpenVPN -func (p *Profile) ShouldFailover() bool { - // old servers don't support it, only failover in case OpenVPN is supported - if len(p.VPNProtoTransportList) == 0 { - // this checks VPNProtoList - return p.HasOpenVPN() - } - for _, c := range p.VPNProtoTransportList { - if c == "wireguard+tcp" { - return true - } - if c == "openvpn+tcp" { - return true - } - } - return false -} - -// HasOpenVPN returns whether or not the profile has OpenVPN support -func (p *Profile) HasOpenVPN() bool { - return hasProtocol(p.VPNProtoList, protocol.OpenVPN) -} - -// HasWireGuard returns whether or not the profile has WireGuard support -func (p *Profile) HasWireGuard() bool { - return hasProtocol(p.VPNProtoList, protocol.WireGuard) -} - -// Public gets the server list as a structure that we return to clients -func (i Info) Public() server.Profiles { - m := make(map[string]server.Profile) - for _, p := range i.Info.ProfileList { - m[p.ID] = server.Profile{ - DisplayName: map[string]string{ - "en": p.DisplayName, - }, - DefaultGateway: p.DefaultGateway, - Priority: p.Priority, - } - } - return server.Profiles{Map: m} -} diff --git a/internal/api/redirect.go b/internal/api/redirect.go deleted file mode 100644 index 417edf5..0000000 --- a/internal/api/redirect.go +++ /dev/null @@ -1,28 +0,0 @@ -package api - -// customRedirects supplies redirect values that should be handled by the app itself -// here we hardcode the redirect values that we should use in the OAuth requests -// these values were taken from https://codeberg.org/eduVPN/vpn-user-portal/src/branch/v3/src/OAuth/VpnClientDb.php -var customRedirects = map[string]string{ - "org.letsconnect-vpn.app.macos": "org.letsconnect-vpn.app.macos:/api/callback", - "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback", - "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app.android:/api/callback", - "org.eduvpn.app.macos": "org.eduvpn.app.macos:/api/callback", - "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback", - "org.eduvpn.app.android": "org.eduvpn.app.android:/api/callback", - "org.govvpn.app.macos": "org.govvpn.app.macos:/api/callback", - "org.govvpn.app.ios": "org.govvpn.app.ios:/api/callback", - "org.govvpn.app.android": "org.govvpn.app.android:/api/callback", -} - -// customRedirect returns the custom redirect string for the clientID `cid` -// Empty string if none is defined or one is defined but is empty. -// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects -// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests -func customRedirect(cid string) string { - v, ok := customRedirects[cid] - if !ok { - return "" - } - return v -} diff --git a/internal/commonver/commonver.go b/internal/commonver/commonver.go new file mode 100644 index 0000000..71ac16e --- /dev/null +++ b/internal/commonver/commonver.go @@ -0,0 +1,9 @@ +// Package commonver defines a version string for eduvpn-common that is used for: +// - building +// - the user agent +// - tagging +package commonver + +// Version is the latest version +// Update this when releasing +const Version = "4.0.0" diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 512756b..cde846c 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -11,7 +11,7 @@ import ( "sync" "time" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" "codeberg.org/eduVPN/eduvpn-common/internal/levenshtein" "codeberg.org/eduVPN/eduvpn-common/internal/verify" discotypes "codeberg.org/eduVPN/eduvpn-common/types/discovery" @@ -87,7 +87,7 @@ type Discovery struct { // mu is the read write mutex that protects the struct from concurrent access mu sync.RWMutex // The httpClient for sending HTTP requests - httpClient *httpw.Client + httpClient *httpwrap.Client // Organizations represents the organizations that are returned by the discovery server OrganizationList Organizations `json:"organizations"` @@ -105,21 +105,21 @@ func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousV var newUpdate time.Time // No HTTP client present, create one if discovery.httpClient == nil { - discovery.httpClient = httpw.NewClient(nil) + discovery.httpClient = httpwrap.NewClient(nil) } // Get json data - jsonURL, err := httpw.JoinURLPath(DiscoURL, jsonFile) + jsonURL, err := httpwrap.JoinURLPath(DiscoURL, jsonFile) if err != nil { return newUpdate, err } - var opts *httpw.OptionalParams + var opts *httpwrap.OptionalParams if !last.IsZero() { header := http.Header{ "If-Modified-Since": []string{last.Format(http.TimeFormat)}, } - opts = &httpw.OptionalParams{ + opts = &httpwrap.OptionalParams{ Headers: header, } } @@ -143,7 +143,7 @@ func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousV // Get signature sigFile := jsonFile + ".minisig" - sigURL, err := httpw.JoinURLPath(DiscoURL, sigFile) + sigURL, err := httpwrap.JoinURLPath(DiscoURL, sigFile) if err != nil { return newUpdate, err } @@ -334,7 +334,7 @@ func (discovery *Discovery) Organizations(ctx context.Context, cache bool) (*Org var jsonDecode Organizations update, err := discovery.file(ctx, file, discovery.OrganizationList.Version, discovery.OrganizationList.UpdateHeader, &jsonDecode) if err != nil { - statErr := &httpw.StatusError{} + statErr := &httpwrap.StatusError{} if errors.As(err, &statErr) { if statErr.Status != 304 { slog.Warn("failed to get fresh organization", "error", err) @@ -383,7 +383,7 @@ func (discovery *Discovery) Servers(ctx context.Context, cache bool) (*Servers, var jsonDecode Servers update, err := discovery.file(ctx, file, discovery.ServerList.Version, discovery.ServerList.UpdateHeader, &jsonDecode) if err != nil { - statErr := &httpw.StatusError{} + statErr := &httpwrap.StatusError{} if errors.As(err, &statErr) { if statErr.Status != 304 { slog.Warn("failed to get fresh servers", "error", err) diff --git a/internal/eduvpnapi/cache.go b/internal/eduvpnapi/cache.go new file mode 100644 index 0000000..d3e2b77 --- /dev/null +++ b/internal/eduvpnapi/cache.go @@ -0,0 +1,67 @@ +package eduvpnapi + +import ( + "context" + "net/http" + "sync" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints" +) + +// EndpointCache is a struct that caches well-known API endpoints +type EndpointCache struct { + lastUpdate map[string]time.Time + lastEP map[string]*endpoints.Endpoints + mu sync.Mutex +} + +// Get returns a cached or fresh endpoint cache copy +func (ec *EndpointCache) Get(ctx context.Context, wk string, transport http.RoundTripper) (*endpoints.Endpoints, error) { + ec.mu.Lock() + defer ec.mu.Unlock() + + // get the last update time + lu := time.Time{} + if v, ok := ec.lastUpdate[wk]; ok { + lu = v + } + + // if not 10 minutes have passed, return cached copy + if !lu.IsZero() && !time.Now().After(lu.Add(10*time.Minute)) { + v, ok := ec.lastEP[wk] + if ok { + return v, nil + } + } + + // get fresh API endpoints + ep, err := getEndpoints(ctx, wk, transport) + if err != nil { + return nil, err + } + + // update endpoints + ec.lastUpdate[wk] = time.Now() + ec.lastEP[wk] = ep + + return ep, nil +} + +var ( + epCache *EndpointCache + epCacheOnce sync.Once +) + +// GetEndpointCache returns the global singleton endpoint cache +// or creates one if it does not exist +func GetEndpointCache() *EndpointCache { + epCacheOnce.Do(func() { + epCache = &EndpointCache{ + lastUpdate: make(map[string]time.Time), + lastEP: make(map[string]*endpoints.Endpoints), + } + }) + + return epCache +} diff --git a/internal/eduvpnapi/eduvpnapi.go b/internal/eduvpnapi/eduvpnapi.go new file mode 100644 index 0000000..62fe0d1 --- /dev/null +++ b/internal/eduvpnapi/eduvpnapi.go @@ -0,0 +1,395 @@ +// Package eduvpnapi implements version 3 of the eduVPN api: https://docs.eduvpn.org/server/v3/api.html +package eduvpnapi + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "time" + + "codeberg.org/jwijenbergh/eduoauth-go/v2" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/wireguard" + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" +) + +// Callbacks is the API callback interface +// It is used to trigger authorization and forward token updates +type Callbacks interface { + // TriggerAuth is called when authorization should be triggered + TriggerAuth(context.Context, string, bool) (string, error) + // AuthDone is called when authorization has just completed + AuthDone(string, server.Type) + // TokensUpdates is called when tokens are updated + TokensUpdated(string, server.Type, eduoauth.Token) +} + +// ServerData is the data for a server that is passed to the API struct +type ServerData struct { + // ID is the identifier for the server + ID string + // Type is the type of server + Type server.Type + // BaseWK is the base well-known endpoint + BaseWK string + // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet + BaseAuthWK string + // ProcessAuth processes the OAuth authorization + ProcessAuth func(context.Context, string) (string, error) + // DisableAuthorize indicates whether or not new authorization requests should be disabled + DisableAuthorize bool + // transport is the HTTP transport, only used for testing currently + transport http.RoundTripper +} + +// Transport returns the transport to be used for the server +// By default it uses the transport from internal/httpwrap DefaultTransport +func (s *ServerData) Transport() http.RoundTripper { + if s.transport == nil { + return httpwrap.DefaultTransport + } + return s.transport +} + +// API is the top-level struct that each method is defined on +type API struct { + cb Callbacks + // oauth is the oauth object + oauth *eduoauth.OAuth + // Data is the server data + Data ServerData +} + +// NewAPI creates a new API object by creating an OAuth object +func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) { + cr := customRedirect(clientID) + // Construct OAuth + + transp := sd.Transport() + post := true + // we do not support non-loopback clients with response_mode form_post + if cr != "" { + post = false + } + o := eduoauth.OAuth{ + ClientID: clientID, + EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) + if err != nil { + return nil, err + } + return &eduoauth.EndpointResponse{ + AuthorizationURL: ep.API.V3.Authorization, + TokenURL: ep.API.V3.Token, + }, nil + }, + CustomRedirect: cr, + FormPost: post, + RedirectPath: "/callback", + TokensUpdated: func(tok eduoauth.Token) { + cb.TokensUpdated(sd.ID, sd.Type, tok) + }, + Transport: transp, + UserAgent: httpwrap.UserAgent, + } + + if tokens != nil { + o.UpdateTokens(*tokens) + } + + api := &API{ + cb: cb, + oauth: &o, + Data: sd, + } + err := api.authorize(ctx) + if err != nil { + return nil, err + } + return api, nil +} + +// ErrAuthorizeDisabled is returned when authorization is disabled but is needed to complete +var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled") + +func (a *API) authorize(ctx context.Context) (err error) { + _, err = a.oauth.AccessToken(ctx) + // already authorized + if err == nil { + return nil + } + + // otherwise check if invalid tokens, + // if not then something else is wrong with the API + // return an error + tErr := &eduoauth.TokensInvalidError{} + if !errors.As(err, &tErr) { + return err + } + + if a.Data.DisableAuthorize { + return ErrAuthorizeDisabled + } + + defer func() { + if err == nil { + a.cb.AuthDone(a.Data.ID, a.Data.Type) + } + }() + + scope := "config" + url, err := a.oauth.AuthURL(ctx, scope) + if err != nil { + return err + } + if a.Data.ProcessAuth != nil { + url, err = a.Data.ProcessAuth(ctx, url) + if err != nil { + return err + } + } + // We expect an uri if custom redirect is non empty + uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") + if err != nil { + return err + } + // The uri is only given here if a custom redirect is done + err = a.oauth.Exchange(ctx, uri) + if err != nil { + return err + } + return nil +} + +func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) { + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) + if err != nil { + return nil, nil, err + } + u := ep.API.V3.API + endpoint + + // TODO: Cache HTTP client? + httpC := httpwrap.NewClient(a.oauth.NewHTTPClient()) + return httpC.Do(ctx, method, u, opts) +} + +func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) { + h, body, err := a.authorized(ctx, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpwrap.StatusError{} + // Only retry authorized if we get an HTTP 401 + // TODO: Can the OAuth client handle this instead? + if errors.As(err, &statErr) && statErr.Status == 401 { + slog.Debug("Got a HTTP 401. Marking tokens as expired...", "HTTP method", method, "endpoint", endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + a.oauth.SetTokenExpired() + h, body, err = a.authorized(ctx, method, endpoint, opts) + } + // Tokens is invalid we need to renew and authorize again + tErr := &eduoauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + // Mark the token as invalid and retry, so we trigger the authorization flow + a.oauth.SetTokenRenew() + slog.Debug("The tokens were invalid, trying again...") + if autherr := a.authorize(ctx); autherr != nil { + return nil, nil, autherr + } + return a.authorized(ctx, method, endpoint, opts) + } + return h, body, err +} + +// Disconnect disconnects a client from the server by sending a /disconnect API call +// This cleans up resources such as WireGuard IP allocation +func (a *API) Disconnect(ctx context.Context) error { + _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpwrap.OptionalParams{Timeout: 5 * time.Second}) + return err +} + +// Info does the /info API call +func (a *API) Info(ctx context.Context) (*profiles.Info, error) { + _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) + if err != nil { + return nil, fmt.Errorf("failed API /info: %w", err) + } + p := profiles.Info{} + if err = json.Unmarshal(body, &p); err != nil { + return nil, fmt.Errorf("failed API /info: %w", err) + } + return &p, nil +} + +// ConnectData is the data that is returned when the /connect call completes without error +type ConnectData struct { + // Configuration is the VPN configuration + Configuration string + // Protocol tells us what protocol it is, OpenVPN or WireGuard (proxied or not) + Protocol protocol.Protocol + // Expires tells us when this configuration expires + Expires time.Time +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func protocolFromCT(ct string) (protocol.Protocol, error) { + switch ct { + case "application/x-wireguard-profile": + return protocol.WireGuard, nil + case "application/x-wireguard+tcp-profile": + return protocol.WireGuardProxy, nil + case "application/x-openvpn-profile": + return protocol.OpenVPN, nil + } + return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct) +} + +// ErrNoProtocols is returned when a connect call is given with an empty protocol slice +var ErrNoProtocols = errors.New("no protocols supplied") + +// ErrUnknownProtocol is returned when the client in a connect gives an unknown protocol +var ErrUnknownProtocol = errors.New("unknown protocol supplied") + +// Connect sends a /connect to an eduVPN server +// `ctx` is the context used for cancellation +// protos is the list of protocols supported and wanted by the client +func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + uv := url.Values{ + "profile_id": {prof.ID}, + } + + if len(protos) == 0 { + return nil, ErrNoProtocols + } + + var wgKey *wgtypes.Key + + // Loop over the protocols and set the correct headers and values + for _, p := range protos { + switch p { + case protocol.WireGuard: + gk, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + wgKey = &gk + // Set the public key + pubkey := wgKey.PublicKey() + uv.Set("public_key", pubkey.String()) + hdrs.Add("accept", "application/x-wireguard-profile") + hdrs.Add("accept", "application/x-wireguard+tcp-profile") + case protocol.OpenVPN: + hdrs.Add("accept", "application/x-openvpn-profile") + default: + return nil, ErrUnknownProtocol + } + } + // set prefer TCP + uv.Set("prefer_tcp", boolToYesNo(pTCP)) + + // Construct the parameters + params := &httpwrap.OptionalParams{Headers: hdrs, Body: uv} + h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) + if err != nil { + return nil, fmt.Errorf("failed API /connect call: %w", err) + } + + // Parse expiry + expH := h.Get("expires") + if expH == "" { + return nil, errors.New("the server did not give an expires header") + } + expT, err := http.ParseTime(expH) + if err != nil { + return nil, fmt.Errorf("failed parsing expiry time: %w", err) + } + + vpnCfg := string(body) + // Parse content type + contentH := h.Get("content-type") + proto, err := protocolFromCT(contentH) + if err != nil { + return nil, err + } + + if proto == protocol.OpenVPN { + // ensure scripts are not ran by default by append script-security 0 to the config + vpnCfg += "\nscript-security 0" + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil + } + + vpnCfg, err = wireguard.Config(vpnCfg, wgKey) + if err != nil { + return nil, err + } + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil +} + +func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) { + uStr, err := httpwrap.JoinURLPath(url, "/.well-known/vpn-user-portal") + if err != nil { + return nil, err + } + httpC := httpwrap.NewClient(nil) + httpC.Client.Transport = tp + _, body, err := httpC.Get(ctx, uStr) + if err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + err = ep.Validate() + if err != nil { + return nil, err + } + return &ep, nil +} + +// OAuthLogger is defined here to update the internal logger +// for the eduoauth library +type OAuthLogger struct{} + +// Logf logs a message with parameters +func (ol *OAuthLogger) Logf(msg string, params ...any) { + slog.Debug("OAuth log", "log", fmt.Sprintf(msg, params...)) +} + +// Log logs a message +func (ol *OAuthLogger) Log(msg string) { + slog.Debug("OAuth log", "log", msg) +} + +func init() { + eduoauth.UpdateLogger(&OAuthLogger{}) +} diff --git a/internal/eduvpnapi/eduvpnapi_test.go b/internal/eduvpnapi/eduvpnapi_test.go new file mode 100644 index 0000000..23f895b --- /dev/null +++ b/internal/eduvpnapi/eduvpnapi_test.go @@ -0,0 +1,513 @@ +package eduvpnapi + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "slices" + "strings" + "testing" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/test" + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" + "codeberg.org/jwijenbergh/eduoauth-go/v2" +) + +func tokenHandler(t *testing.T, gt []string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for token handler: %v", r.Method) + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + grant := parsed.Get("grant_type") + + if slices.Contains(gt, grant) { + _, err = w.Write([]byte(` +{ + "access_token": "validaccess", + "refresh_token": "validrefresh", + "expires_in": 3600 +} + `)) + if err != nil { + t.Fatalf("failed writing in token handler: %v", err) + } + return + } + t.Fatalf("grant type: %v, not allowed", grant) + } +} + +func checkAuthBearer(t *testing.T, r *http.Request) { + authh := r.Header.Get("Authorization") + if !strings.HasPrefix(authh, "Bearer ") { + t.Fatalf("API call is not given with an authorization Bearer header, got: %v", authh) + } +} + +func connectHandler(t *testing.T, proto string, exp time.Time) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for connect handler: %v", r.Method) + } + checkAuthBearer(t, r) + w.Header().Set("expires", exp.Format(http.TimeFormat)) + w.Header().Set("content-type", fmt.Sprintf("application/x-%s-profile", proto)) + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + // the wireguard config we parse + var cfg string + if proto == "openvpn" { + cfg = "openvpnconfig" + } else { + if parsed.Get("public_key") == "" { + t.Fatalf("no public_key given") + } + if proto == "wireguard+tcp" { + ptcp := parsed.Get("prefer_tcp") + if ptcp != "yes" { + t.Fatalf("prefer TCP is not yes: %s", ptcp) + } + cfg = ` +[Interface] +[Peer] +ProxyEndpoint = https://proxyendpoint +` + } else { + cfg = "[Interface]" + } + } + _, err = w.Write([]byte(cfg)) + if err != nil { + t.Fatalf("failed writing /connect response: %v", err) + } + } +} + +func disconnectHandler(t *testing.T) func(http.ResponseWriter, *http.Request) { + return func(_ http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for disconnect handler: %v", r.Method) + } + checkAuthBearer(t, r) + } +} + +type TestCallback struct { + t *testing.T +} + +func (tc *TestCallback) TriggerAuth(_ context.Context, str string, _ bool) (string, error) { + go func() { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + ru, err := url.Parse(u.Query().Get("redirect_uri")) + if err != nil { + panic(err) + } + oq := u.Query() + q := ru.Query() + q.Set("state", oq.Get("state")) + q.Set("code", "fakeauthcode") + ru.RawQuery = q.Encode() + + c := http.Client{} + req, err := http.NewRequest("GET", ru.String(), nil) + if err != nil { + panic(err) + } + _, err = c.Do(req) + if err != nil { + panic(err) + } + }() + return "", nil +} +func (tc *TestCallback) AuthDone(string, server.Type) {} +func (tc *TestCallback) TokensUpdated(string, server.Type, eduoauth.Token) {} + +// create a API struct with allowed grant types +func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.HandlerPath) (*API, *test.Server) { + // Create a simple API client and check if the fields are created correctly + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to setup listener for test server: %v", err) + } + + hps = append(hps, []test.HandlerPath{ + { + Method: http.MethodGet, + Path: "/.well-known/vpn-user-portal", + Response: fmt.Sprintf(` +{ + "api": { + "http://eduvpn.org/api#3": { + "api_endpoint": "https://%[1]s/test-api-endpoint", + "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", + "token_endpoint": "https://%[1]s/test-token-endpoint" + } + }, + "v": "0.0.0" +} +`, listen.Addr().String()), + }, + { + Path: "/test-token-endpoint", + ResponseHandler: tokenHandler(t, gt), + }, + }...) + // start server + serv := test.NewServerWithHandles(hps, listen) + servc, err := serv.Client() + if err != nil { + t.Fatalf("failed to setup HTTP test server client: %v", servc) + } + + sd := ServerData{ + ID: "randomidentifier", + Type: server.TypeCustom, + BaseWK: serv.URL, + BaseAuthWK: serv.URL, + ProcessAuth: func(_ context.Context, in string) (string, error) { + return in, nil + }, + DisableAuthorize: false, + transport: servc.Client.Transport, + } + + tc := &TestCallback{t: t} + + a, err := NewAPI(context.Background(), "testclient", sd, tc, tok) + if err != nil { + t.Fatalf("failed creating API: %v", err) + } + return a, serv +} + +func TestNewAPI(t *testing.T) { + gts := []string{"refresh_token"} + tok := &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "expiredrefresh", + // tokens are expired, let's try authorizing + ExpiredTimestamp: time.Now(), + } + a, srv := createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } + + gts = []string{"authorization_code"} + tok = &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "", + ExpiredTimestamp: time.Now(), + } + a, srv = createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } +} + +func TestAPIInfo(t *testing.T) { + // auth should not be triggered + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + statErr := &httpwrap.StatusError{} + cases := []struct { + hp test.HandlerPath + info *profiles.Info + err any + }{ + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "default_gateway": false, + "display_name": "test profile 1", + "profile_id": "test1", + "profile_priority": 3, + "vpn_proto_list": [ + "openvpn", + "wireguard" + ] + } + ] + } +} +`, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test1", + DisplayName: "test profile 1", + VPNProtoList: []string{"openvpn", "wireguard"}, + Priority: 3, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "display_name": "test profile 2", + "profile_id": "test2", + "vpn_proto_list": [ + "wireguard" + ] + } + ] + } +} +`, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test2", + DisplayName: "test profile 2", + VPNProtoList: []string{"wireguard"}, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: "", + ResponseCode: 404, + }, + info: nil, + err: &statErr, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gprfs, err := a.Info(context.Background()) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed profiles info: %v but want no error", err) + } + + if !errors.As(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err.Error()) + } + } else if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + + if !reflect.DeepEqual(gprfs, c.info) { + t.Fatalf("got info: %v, not equal to want: %v", gprfs, c.info) + } + } +} + +func TestAPIConnect(t *testing.T) { + // auth should not be triggered + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + cases := []struct { + hp test.HandlerPath + cd *ConnectData + prof profiles.Profile + protos []protocol.Protocol + ptcp bool + err error + }{ + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + err: ErrNoProtocols, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + protos: []protocol.Protocol{protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "openvpn", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: "openvpnconfig\nscript-security 0", + Protocol: protocol.OpenVPN, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuard, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard+tcp", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuardProxy, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + ptcp: true, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gcd, err := a.Connect(context.Background(), c.prof, c.protos, c.ptcp) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed connect: %v but want no error", err) + } + + if !errors.Is(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err) + } + } else if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + + if gcd != nil && c.cd != nil { + m, err := regexp.MatchString(c.cd.Configuration, gcd.Configuration) + if err != nil { + t.Fatalf("failed matching regexp: %v", err) + } + if !m { + t.Fatalf("regex:\n%s\ndoes not match config:\n%s", c.cd.Configuration, gcd.Configuration) + } + // we have already checked the config using a regex + c.cd.Configuration = gcd.Configuration + + } + if !reflect.DeepEqual(gcd, c.cd) { + t.Fatalf("got connect data: %v, not equal to want: %v", gcd, c.cd) + } + } +} + +func TestDisconnect(t *testing.T) { + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{ + { + Path: "/test-api-endpoint/disconnect", + ResponseHandler: disconnectHandler(t), + }, + }) + defer srv.Close() + err := a.Disconnect(context.Background()) + if err != nil { + t.Fatalf("failed /disconnect: %v", err) + } +} diff --git a/internal/eduvpnapi/endpoints/endpoints.go b/internal/eduvpnapi/endpoints/endpoints.go new file mode 100644 index 0000000..c98d2c7 --- /dev/null +++ b/internal/eduvpnapi/endpoints/endpoints.go @@ -0,0 +1,62 @@ +// Package endpoints defines a wrapper around the various +// endpoints returned by an eduVPN server in well-known +package endpoints + +import ( + "fmt" + "net/url" +) + +// List is the list of endpoints as returned by the eduVPN server +type List struct { + // API is the API endpoint which we use for calls such as /info, /connect, ... + API string `json:"api_endpoint"` + // Authorization is the authorization endpoint for OAuth + Authorization string `json:"authorization_endpoint"` + // Token is the token endpoint for OAuth + Token string `json:"token_endpoint"` +} + +// Versions is the endpoints separated by API version +type Versions struct { + // V2 is the legacy V2 API, this is not used + V2 List `json:"http://eduvpn.org/api#2"` + // V3 is the newest API, which we use + V3 List `json:"http://eduvpn.org/api#3"` +} + +// Endpoints defines the json format for /.well-known/vpn-user-portal". +type Endpoints struct { + // API defines the API endpoints, split by version + API Versions `json:"api"` + // V is the version string for the server + V string `json:"v"` +} + +// Validate validates the endpoints by parsing them and checking the scheme is HTTP +// An error is returned if they are not valid +func (e Endpoints) Validate() error { + v3 := e.API.V3 + pAPI, err := url.Parse(v3.API) + if err != nil { + return fmt.Errorf("failed to parse API endpoint: %w", err) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return fmt.Errorf("failed to parse API authorization endpoint: %w", err) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return fmt.Errorf("failed to parse API token endpoint: %w", err) + } + if pAPI.Scheme != "https" { + return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme) + } + if pAPI.Scheme != pAuth.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + return nil +} diff --git a/internal/eduvpnapi/profiles/profiles.go b/internal/eduvpnapi/profiles/profiles.go new file mode 100644 index 0000000..77109f1 --- /dev/null +++ b/internal/eduvpnapi/profiles/profiles.go @@ -0,0 +1,119 @@ +// Package profiles defines a wrapper around the various profiles +// returned by the /info endpoint +package profiles + +import ( + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" +) + +// Profile is the information for a profile +type Profile struct { + // ID is the identifier of the profile + // Used to select a profile + ID string `json:"profile_id"` + // DisplayName defines the UI friendly name for the profile + DisplayName string `json:"display_name"` + // VPNProtoList defines the list of VPN protocols + // E.g. wireguard, openvpn + VPNProtoList []string `json:"vpn_proto_list"` + // VPNProtoTransportList defines the list of VPN protocols including their transport values + // E.g. wireguard+udp, openvpn+tcp + VPNProtoTransportList []string `json:"vpn_proto_transport_list"` + // DefaultGateway specifies whether or not this profile is a default gateway profile + DefaultGateway bool `json:"default_gateway"` + // DNSSearchDomains specifies the list of dns search domains + // This is provided for a Linux client issue + // See: https://github.com/eduvpn/python-eduvpn-client/issues/550 + DNSSearchDomains []string `json:"dns_search_domain_list"` + // Priority is the priority of the profile for sorting in the UI + // the higher the priority, the higher it should be in the list + Priority int `json:"profile_priority"` +} + +// ListInfo is the struct that has the profile list +type ListInfo struct { + ProfileList []Profile `json:"profile_list"` +} + +// Info is the top-level struct for the info endpoint +type Info struct { + Info ListInfo `json:"info"` +} + +// Len returns the length of the profile list +func (i Info) Len() int { + return len(i.Info.ProfileList) +} + +// Get returns a profile with id `id`, it returns nil if it is not found +func (i Info) Get(id string) *Profile { + for _, p := range i.Info.ProfileList { + if p.ID == id { + return &p + } + } + return nil +} + +// MustIndex gets a profile by index +// This index must be in the bounds +func (i Info) MustIndex(n int) Profile { + return i.Info.ProfileList[n] +} + +func hasProtocol(protos []string, proto protocol.Protocol) bool { + for _, p := range protos { + if protocol.New(p) == proto { + return true + } + } + return false +} + +// ShouldFailover returns whether or not this VPN profile should start a failover procedure +// This is true when the profile supports a TCP connection +// If we cannot determine whether it supports a TCP connection +// (because the server doesn't provide the VPN transport list function yet), +// we will just check if it supports OpenVPN +func (p *Profile) ShouldFailover() bool { + // old servers don't support it, only failover in case OpenVPN is supported + if len(p.VPNProtoTransportList) == 0 { + // this checks VPNProtoList + return p.HasOpenVPN() + } + for _, c := range p.VPNProtoTransportList { + if c == "wireguard+tcp" { + return true + } + if c == "openvpn+tcp" { + return true + } + } + return false +} + +// HasOpenVPN returns whether or not the profile has OpenVPN support +func (p *Profile) HasOpenVPN() bool { + return hasProtocol(p.VPNProtoList, protocol.OpenVPN) +} + +// HasWireGuard returns whether or not the profile has WireGuard support +func (p *Profile) HasWireGuard() bool { + return hasProtocol(p.VPNProtoList, protocol.WireGuard) +} + +// Public gets the server list as a structure that we return to clients +func (i Info) Public() server.Profiles { + m := make(map[string]server.Profile) + for _, p := range i.Info.ProfileList { + m[p.ID] = server.Profile{ + DisplayName: map[string]string{ + "en": p.DisplayName, + }, + DefaultGateway: p.DefaultGateway, + Priority: p.Priority, + } + } + return server.Profiles{Map: m} +} diff --git a/internal/eduvpnapi/redirect.go b/internal/eduvpnapi/redirect.go new file mode 100644 index 0000000..7af31fb --- /dev/null +++ b/internal/eduvpnapi/redirect.go @@ -0,0 +1,28 @@ +package eduvpnapi + +// customRedirects supplies redirect values that should be handled by the app itself +// here we hardcode the redirect values that we should use in the OAuth requests +// these values were taken from https://codeberg.org/eduVPN/vpn-user-portal/src/branch/v3/src/OAuth/VpnClientDb.php +var customRedirects = map[string]string{ + "org.letsconnect-vpn.app.macos": "org.letsconnect-vpn.app.macos:/api/callback", + "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback", + "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app.android:/api/callback", + "org.eduvpn.app.macos": "org.eduvpn.app.macos:/api/callback", + "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback", + "org.eduvpn.app.android": "org.eduvpn.app.android:/api/callback", + "org.govvpn.app.macos": "org.govvpn.app.macos:/api/callback", + "org.govvpn.app.ios": "org.govvpn.app.ios:/api/callback", + "org.govvpn.app.android": "org.govvpn.app.android:/api/callback", +} + +// customRedirect returns the custom redirect string for the clientID `cid` +// Empty string if none is defined or one is defined but is empty. +// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects +// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests +func customRedirect(cid string) string { + v, ok := customRedirects[cid] + if !ok { + return "" + } + return v +} diff --git a/internal/http/http.go b/internal/http/http.go deleted file mode 100644 index 7b9b70d..0000000 --- a/internal/http/http.go +++ /dev/null @@ -1,288 +0,0 @@ -// Package http defines higher level helpers for the net/http package -package http - -import ( - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "log/slog" - "net/http" - "net/url" - "path" - "strings" - "time" - - "codeberg.org/eduVPN/eduvpn-common/internal/version" -) - -// UserAgent is the user agent that is used for requests -var UserAgent string - -// URLParameters is a type used for the parameters in the URL. -type URLParameters map[string]string - -// OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call. -type OptionalParams struct { - Headers http.Header - URLParameters URLParameters - Body url.Values - Timeout time.Duration -} - -func cleanPath(u *url.URL, trailing bool) string { - if u.Path != "" { - // Clean the path - // https://pkg.go.dev/path#Clean - u.Path = path.Clean(u.Path) - } - - str := u.String() - - // Make sure the URL ends with a / - if trailing && str[len(str)-1:] != "/" { - str += "/" - } - return str -} - -// EnsureValidURL ensures that the input URL is valid to be used internally -// It does the following -// - Sets the scheme to https if none is given -// - It 'cleans' up the path using path.Clean -// - It makes sure that the URL ends with a / -// It returns an error if the URL cannot be parsed. -func EnsureValidURL(s string, trailing bool) (string, error) { - u, err := url.Parse(s) - if err != nil { - return "", fmt.Errorf("failed parsing url with error: %w", err) - } - - // Make sure the scheme is always https - if u.Scheme != "https" { - u.Scheme = "https" - } - return cleanPath(u, trailing), nil -} - -// JoinURLPath joins url's path, in go 1.19 we can use url.JoinPath -func JoinURLPath(u string, p string) (string, error) { - pu, err := url.Parse(u) - if err != nil { - return "", fmt.Errorf("failed to parse url for joining paths with error: %w", err) - } - pp, err := url.Parse(p) - if err != nil { - return "", fmt.Errorf("failed to parse path for joining paths with error: %w", err) - } - fp := pu.ResolveReference(pp) - - // We also clean the path for consistency - return cleanPath(fp, false), nil -} - -// ConstructURL creates a URL with the included parameters. -func ConstructURL(u *url.URL, params URLParameters) (string, error) { - q := u.Query() - - for p, value := range params { - q.Set(p, value) - } - u.RawQuery = q.Encode() - return u.String(), nil -} - -// optionalURL ensures that the URL contains the optional parameters -// it returns the url (with parameters if success) and an error indicating success. -func optionalURL(urlStr string, opts *OptionalParams) (string, error) { - u, err := url.Parse(urlStr) - if err != nil { - return "", fmt.Errorf("failed to construct parse url '%s' with error: %w", urlStr, err) - } - // Make sure the scheme is always set to HTTPS - if u.Scheme != "https" { - u.Scheme = "https" - } - - if opts == nil { - return u.String(), nil - } - - return ConstructURL(u, opts.URLParameters) -} - -// optionalHeaders ensures that the HTTP request uses the optional headers if defined. -func optionalHeaders(req *http.Request, opts *OptionalParams) { - // Add headers - if opts != nil && req != nil && opts.Headers != nil { - for k, v := range opts.Headers { - for _, cv := range v { - req.Header.Add(k, cv) - } - } - } -} - -// optionalBodyReader returns a HTTP body reader if there is a body, otherwise nil. -func optionalBodyReader(opts *OptionalParams) io.Reader { - if opts != nil && opts.Body != nil { - return strings.NewReader(opts.Body.Encode()) - } - return nil -} - -// Client is a wrapper around http.Client with some convenience features -// - A default timeout of 5 seconds -// - A read limiter to prevent servers from sending large amounts of data -// - Checking on http code with custom errors -type Client struct { - // Client is the HTTP Client that sends the request - Client *http.Client - // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses - // This is used to prevent servers from sending huge amounts of data - // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems - ReadLimit int64 - - // Timeout denotes the default timeout for each request - Timeout time.Duration -} - -// tls13Transport returns a http.Transport with the minimum TLS version set to 1.3 -func tls13Transport() *http.Transport { - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS13} - return tr -} - -// DefaultTransport is the default HTTP transport to use -// by default it is a transport that only allows TLS 1.3 -var DefaultTransport = tls13Transport() - -// NewClient returns a HTTP client with some default settings -func NewClient(client *http.Client) *Client { - c := client - if c == nil { - c = &http.Client{ - Transport: DefaultTransport, - } - } - // if a client is non-nil it uses its own transport - // for the OAuth client we also make sure TLS 1.3 is set - // TODO: Should we double verify that MinVersion is 1.3 or is that overkill? - - // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses - // This is used to prevent servers from sending huge amounts of data - // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems - // The timeout is 10 seconds by default. We pass it here and not in the http client because we want to do it per request - return &Client{Client: c, ReadLimit: 16 << 20, Timeout: 10 * time.Second} -} - -// Get creates a Get request and returns the headers, body and an error. -func (c *Client) Get(ctx context.Context, url string) (http.Header, []byte, error) { - return c.Do(ctx, http.MethodGet, url, nil) -} - -// Do sends a HTTP request using a method (e.g. GET, POST), an url and optional parameters -// It returns the HTTP headers, the body and an error if there is one. -func (c *Client) Do(ctx context.Context, method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) { - // Make sure the url contains all the parameters - // This can return an error, - // it already has the right error, so we don't wrap it further - urlStr, err := optionalURL(urlStr, opts) - if err != nil { - // No further type wrapping is needed here - return nil, nil, err - } - - // The timeout is configurable for each request - timeout := c.Timeout - if opts != nil && opts.Timeout.Seconds() > 0 { - timeout = opts.Timeout - } - - ctx, cncl := context.WithTimeout(ctx, timeout) - defer cncl() - - slog.Debug("sending request", "method", method, "url", urlStr) - - // Create request object with the body reader generated from the optional arguments - req, err := http.NewRequestWithContext(ctx, method, urlStr, optionalBodyReader(opts)) - if err != nil { - return nil, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s' and error: %w", method, urlStr, err) - } - if UserAgent != "" { - req.Header.Add("User-Agent", UserAgent) - } - - // Make sure the headers contain all the parameters - optionalHeaders(req, opts) - - // Do request - res, err := c.Client.Do(req) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - return nil, nil, &TimeoutError{URL: urlStr, Method: method} - } - return nil, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s' and error: %w", method, urlStr, err) - } - - // Request successful, make sure body is closed at the end - defer func() { - _ = res.Body.Close() - }() - - // Return a string - // A max bytes reader is normally used for request bodies with a writer - // However, this is still nice to use because unlike a limitreader, it returns an error if the body is too large - // We use this function without a writer so we pass nil - // We impose a limit because servers could be malicious and send huge amounts of data - r := http.MaxBytesReader(nil, res.Body, c.ReadLimit) - body, err := io.ReadAll(r) - if err != nil { - return res.Header, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s', max bytes size: '%v' and error: %w", method, urlStr, c.ReadLimit, err) - } - if res.StatusCode < 200 || res.StatusCode > 299 { - return res.Header, body, fmt.Errorf("failed HTTP request with method: '%s' due to a status error: %w", method, &StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}) - } - - // Return the body in bytes and signal the status error if there was one - return res.Header, body, nil -} - -// TimeoutError indicates that we have gotten a timeout -type TimeoutError struct { - URL string - Method string -} - -// Error returns the TimeoutError as an error string. -func (e *TimeoutError) Error() string { - return fmt.Sprintf( - "timeout in obtaining HTTP resource: '%s' with method: '%s'", - e.URL, - e.Method, - ) -} - -// StatusError indicates that we have received a HTTP status error. -type StatusError struct { - URL string - Body string - Status int -} - -// Error returns the StatusError as an error string. -func (e *StatusError) Error() string { - return fmt.Sprintf( - "failed obtaining HTTP resource: '%s' as it gave an unsuccessful status code: '%d'. Body: '%s'", - e.URL, - e.Status, - e.Body, - ) -} - -// RegisterAgent registers the user agent for client and version -func RegisterAgent(client string, verApp string) { - UserAgent = fmt.Sprintf("%s/%s eduvpn-common/%s", client, verApp, version.Version) -} diff --git a/internal/http/http_test.go b/internal/http/http_test.go deleted file mode 100644 index 8c2ae0f..0000000 --- a/internal/http/http_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package http - -import ( - "testing" -) - -func TestEnsureValidURL(t *testing.T) { - _, validErr := EnsureValidURL("%notvalid%", true) - - if validErr == nil { - t.Fatal("Got nil error, want: non-nil") - } - - testCases := map[string]string{ - // Make sure we set https - "example.com/": "https://example.com/", - // Make sure we do override the scheme to https - "http://example.com/": "https://example.com/", - // This URL is already valid - "https://example.com/": "https://example.com/", - // Make sure to add a trailing slash (/) - "https://example.com": "https://example.com/", - // Cleanup the path 1 - "https://example.com/////": "https://example.com/", - // Cleanup the path 2 - "https://example.com/..": "https://example.com/", - } - - for k, v := range testCases { - valid, validErr := EnsureValidURL(k, true) - if validErr != nil { - t.Fatalf("Got: %v, want: nil", validErr) - } - if valid != v { - t.Fatalf("Got: %v, want: %v", valid, v) - } - } -} - -func Test_JoinURLPath(t *testing.T) { - cases := []struct { - u string - p string - want string - }{ - {u: "https://example.com", p: "test", want: "https://example.com/test"}, - {u: "https://example.com", p: "/test", want: "https://example.com/test"}, - {u: "https://example.com", p: "../test", want: "https://example.com/test"}, - {u: "https://example.com", p: "../test/", want: "https://example.com/test"}, - {u: "https://example.com", p: "test/", want: "https://example.com/test"}, - } - for _, c := range cases { - got, err := JoinURLPath(c.u, c.p) - if err != nil { - t.Fatalf("Failed to parse join url case: %v, err: %v", c, err) - } - if got != c.want { - t.Fatalf("Failed test case for joining URL, want: %v, got: %v", c.want, got) - } - } -} diff --git a/internal/httpwrap/httpwrap.go b/internal/httpwrap/httpwrap.go new file mode 100644 index 0000000..5fd42c8 --- /dev/null +++ b/internal/httpwrap/httpwrap.go @@ -0,0 +1,288 @@ +// Package httpwrap defines higher level helpers for the net/http package +package httpwrap + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "path" + "strings" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/commonver" +) + +// UserAgent is the user agent that is used for requests +var UserAgent string + +// URLParameters is a type used for the parameters in the URL. +type URLParameters map[string]string + +// OptionalParams is a structure that defines the optional parameters that are given when making a HTTP call. +type OptionalParams struct { + Headers http.Header + URLParameters URLParameters + Body url.Values + Timeout time.Duration +} + +func cleanPath(u *url.URL, trailing bool) string { + if u.Path != "" { + // Clean the path + // https://pkg.go.dev/path#Clean + u.Path = path.Clean(u.Path) + } + + str := u.String() + + // Make sure the URL ends with a / + if trailing && str[len(str)-1:] != "/" { + str += "/" + } + return str +} + +// EnsureValidURL ensures that the input URL is valid to be used internally +// It does the following +// - Sets the scheme to https if none is given +// - It 'cleans' up the path using path.Clean +// - It makes sure that the URL ends with a / +// It returns an error if the URL cannot be parsed. +func EnsureValidURL(s string, trailing bool) (string, error) { + u, err := url.Parse(s) + if err != nil { + return "", fmt.Errorf("failed parsing url with error: %w", err) + } + + // Make sure the scheme is always https + if u.Scheme != "https" { + u.Scheme = "https" + } + return cleanPath(u, trailing), nil +} + +// JoinURLPath joins url's path, in go 1.19 we can use url.JoinPath +func JoinURLPath(u string, p string) (string, error) { + pu, err := url.Parse(u) + if err != nil { + return "", fmt.Errorf("failed to parse url for joining paths with error: %w", err) + } + pp, err := url.Parse(p) + if err != nil { + return "", fmt.Errorf("failed to parse path for joining paths with error: %w", err) + } + fp := pu.ResolveReference(pp) + + // We also clean the path for consistency + return cleanPath(fp, false), nil +} + +// ConstructURL creates a URL with the included parameters. +func ConstructURL(u *url.URL, params URLParameters) (string, error) { + q := u.Query() + + for p, value := range params { + q.Set(p, value) + } + u.RawQuery = q.Encode() + return u.String(), nil +} + +// optionalURL ensures that the URL contains the optional parameters +// it returns the url (with parameters if success) and an error indicating success. +func optionalURL(urlStr string, opts *OptionalParams) (string, error) { + u, err := url.Parse(urlStr) + if err != nil { + return "", fmt.Errorf("failed to construct parse url '%s' with error: %w", urlStr, err) + } + // Make sure the scheme is always set to HTTPS + if u.Scheme != "https" { + u.Scheme = "https" + } + + if opts == nil { + return u.String(), nil + } + + return ConstructURL(u, opts.URLParameters) +} + +// optionalHeaders ensures that the HTTP request uses the optional headers if defined. +func optionalHeaders(req *http.Request, opts *OptionalParams) { + // Add headers + if opts != nil && req != nil && opts.Headers != nil { + for k, v := range opts.Headers { + for _, cv := range v { + req.Header.Add(k, cv) + } + } + } +} + +// optionalBodyReader returns a HTTP body reader if there is a body, otherwise nil. +func optionalBodyReader(opts *OptionalParams) io.Reader { + if opts != nil && opts.Body != nil { + return strings.NewReader(opts.Body.Encode()) + } + return nil +} + +// Client is a wrapper around http.Client with some convenience features +// - A default timeout of 5 seconds +// - A read limiter to prevent servers from sending large amounts of data +// - Checking on http code with custom errors +type Client struct { + // Client is the HTTP Client that sends the request + Client *http.Client + // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses + // This is used to prevent servers from sending huge amounts of data + // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems + ReadLimit int64 + + // Timeout denotes the default timeout for each request + Timeout time.Duration +} + +// tls13Transport returns a http.Transport with the minimum TLS version set to 1.3 +func tls13Transport() *http.Transport { + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS13} + return tr +} + +// DefaultTransport is the default HTTP transport to use +// by default it is a transport that only allows TLS 1.3 +var DefaultTransport = tls13Transport() + +// NewClient returns a HTTP client with some default settings +func NewClient(client *http.Client) *Client { + c := client + if c == nil { + c = &http.Client{ + Transport: DefaultTransport, + } + } + // if a client is non-nil it uses its own transport + // for the OAuth client we also make sure TLS 1.3 is set + // TODO: Should we double verify that MinVersion is 1.3 or is that overkill? + + // ReadLimit denotes the maximum amount of bytes that are read in HTTP responses + // This is used to prevent servers from sending huge amounts of data + // A limit of 16MB, although maybe much larger than needed, ensures that we do not run into problems + // The timeout is 10 seconds by default. We pass it here and not in the http client because we want to do it per request + return &Client{Client: c, ReadLimit: 16 << 20, Timeout: 10 * time.Second} +} + +// Get creates a Get request and returns the headers, body and an error. +func (c *Client) Get(ctx context.Context, url string) (http.Header, []byte, error) { + return c.Do(ctx, http.MethodGet, url, nil) +} + +// Do sends a HTTP request using a method (e.g. GET, POST), an url and optional parameters +// It returns the HTTP headers, the body and an error if there is one. +func (c *Client) Do(ctx context.Context, method string, urlStr string, opts *OptionalParams) (http.Header, []byte, error) { + // Make sure the url contains all the parameters + // This can return an error, + // it already has the right error, so we don't wrap it further + urlStr, err := optionalURL(urlStr, opts) + if err != nil { + // No further type wrapping is needed here + return nil, nil, err + } + + // The timeout is configurable for each request + timeout := c.Timeout + if opts != nil && opts.Timeout.Seconds() > 0 { + timeout = opts.Timeout + } + + ctx, cncl := context.WithTimeout(ctx, timeout) + defer cncl() + + slog.Debug("sending request", "method", method, "url", urlStr) + + // Create request object with the body reader generated from the optional arguments + req, err := http.NewRequestWithContext(ctx, method, urlStr, optionalBodyReader(opts)) + if err != nil { + return nil, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s' and error: %w", method, urlStr, err) + } + if UserAgent != "" { + req.Header.Add("User-Agent", UserAgent) + } + + // Make sure the headers contain all the parameters + optionalHeaders(req, opts) + + // Do request + res, err := c.Client.Do(req) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return nil, nil, &TimeoutError{URL: urlStr, Method: method} + } + return nil, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s' and error: %w", method, urlStr, err) + } + + // Request successful, make sure body is closed at the end + defer func() { + _ = res.Body.Close() + }() + + // Return a string + // A max bytes reader is normally used for request bodies with a writer + // However, this is still nice to use because unlike a limitreader, it returns an error if the body is too large + // We use this function without a writer so we pass nil + // We impose a limit because servers could be malicious and send huge amounts of data + r := http.MaxBytesReader(nil, res.Body, c.ReadLimit) + body, err := io.ReadAll(r) + if err != nil { + return res.Header, nil, fmt.Errorf("failed HTTP request with method: '%s', url: '%s', max bytes size: '%v' and error: %w", method, urlStr, c.ReadLimit, err) + } + if res.StatusCode < 200 || res.StatusCode > 299 { + return res.Header, body, fmt.Errorf("failed HTTP request with method: '%s' due to a status error: %w", method, &StatusError{URL: urlStr, Body: string(body), Status: res.StatusCode}) + } + + // Return the body in bytes and signal the status error if there was one + return res.Header, body, nil +} + +// TimeoutError indicates that we have gotten a timeout +type TimeoutError struct { + URL string + Method string +} + +// Error returns the TimeoutError as an error string. +func (e *TimeoutError) Error() string { + return fmt.Sprintf( + "timeout in obtaining HTTP resource: '%s' with method: '%s'", + e.URL, + e.Method, + ) +} + +// StatusError indicates that we have received a HTTP status error. +type StatusError struct { + URL string + Body string + Status int +} + +// Error returns the StatusError as an error string. +func (e *StatusError) Error() string { + return fmt.Sprintf( + "failed obtaining HTTP resource: '%s' as it gave an unsuccessful status code: '%d'. Body: '%s'", + e.URL, + e.Status, + e.Body, + ) +} + +// RegisterAgent registers the user agent for client and version +func RegisterAgent(client string, verApp string) { + UserAgent = fmt.Sprintf("%s/%s eduvpn-common/%s", client, verApp, commonver.Version) +} diff --git a/internal/httpwrap/httpwrap_test.go b/internal/httpwrap/httpwrap_test.go new file mode 100644 index 0000000..422ee3f --- /dev/null +++ b/internal/httpwrap/httpwrap_test.go @@ -0,0 +1,61 @@ +package httpwrap + +import ( + "testing" +) + +func TestEnsureValidURL(t *testing.T) { + _, validErr := EnsureValidURL("%notvalid%", true) + + if validErr == nil { + t.Fatal("Got nil error, want: non-nil") + } + + testCases := map[string]string{ + // Make sure we set https + "example.com/": "https://example.com/", + // Make sure we do override the scheme to https + "http://example.com/": "https://example.com/", + // This URL is already valid + "https://example.com/": "https://example.com/", + // Make sure to add a trailing slash (/) + "https://example.com": "https://example.com/", + // Cleanup the path 1 + "https://example.com/////": "https://example.com/", + // Cleanup the path 2 + "https://example.com/..": "https://example.com/", + } + + for k, v := range testCases { + valid, validErr := EnsureValidURL(k, true) + if validErr != nil { + t.Fatalf("Got: %v, want: nil", validErr) + } + if valid != v { + t.Fatalf("Got: %v, want: %v", valid, v) + } + } +} + +func Test_JoinURLPath(t *testing.T) { + cases := []struct { + u string + p string + want string + }{ + {u: "https://example.com", p: "test", want: "https://example.com/test"}, + {u: "https://example.com", p: "/test", want: "https://example.com/test"}, + {u: "https://example.com", p: "../test", want: "https://example.com/test"}, + {u: "https://example.com", p: "../test/", want: "https://example.com/test"}, + {u: "https://example.com", p: "test/", want: "https://example.com/test"}, + } + for _, c := range cases { + got, err := JoinURLPath(c.u, c.p) + if err != nil { + t.Fatalf("Failed to parse join url case: %v, err: %v", c, err) + } + if got != c.want { + t.Fatalf("Failed test case for joining URL, want: %v, got: %v", c.want, got) + } + } +} diff --git a/internal/log/log.go b/internal/log/log.go deleted file mode 100644 index 91eaed8..0000000 --- a/internal/log/log.go +++ /dev/null @@ -1,37 +0,0 @@ -// Package log implements a basic level based logger -package log - -import ( - "fmt" - "io" - "log/slog" - "os" - "path" -) - -type Logger struct { - fr *FileRotater -} - -func (l *Logger) Init(dir string) (*slog.Logger, error) { - err := os.MkdirAll(dir, 0o700) - if err != nil { - return nil, err - } - name := path.Join(dir, "log") - - fr, err := NewFileRotater(name) - if err != nil { - return nil, fmt.Errorf("failed creating log rotater: %w", err) - } - l.fr = fr - multi := io.MultiWriter(os.Stdout, fr) - handler := slog.NewTextHandler(multi, &slog.HandlerOptions{ - Level: slog.LevelDebug, - }) - return slog.New(handler), nil -} - -func (l *Logger) Close() error { - return l.fr.Close() -} diff --git a/internal/log/rotate.go b/internal/log/rotate.go deleted file mode 100644 index 2971f70..0000000 --- a/internal/log/rotate.go +++ /dev/null @@ -1,109 +0,0 @@ -package log - -import ( - "io" - "os" - "sync" - - "codeberg.org/eduVPN/eduvpn-common/internal/atomicfile" -) - -var ( - // MaxSize is the maximum size in bytes from when to start trimming - MaxSize int64 = 10 * 1024 * 1024 - - // TrimSize denotes how much to trim from the beginning - TrimSize = MaxSize / 2 - - // TrimMsg is the message to display when it was trimmed - TrimMsg = "--- previous part was trimmed by eduvpn-common as the file was too big (10MB) ---\n" -) - -// FileRotater is a file that is trimmed when a maximum size is reached -// This is for logging useful -type FileRotater struct { - filename string - file *os.File - mu sync.Mutex -} - -// NewFileRotater creates a new log file rotater -func NewFileRotater(filename string) (*FileRotater, error) { - fr := &FileRotater{ - filename: filename, - } - - err := fr.open() - if err != nil { - return nil, err - } - return fr, nil -} - -func (fr *FileRotater) open() error { - f, err := os.OpenFile(fr.filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666) - if err != nil { - return err - } - fr.file = f - return nil -} - -func (fr *FileRotater) trim() error { - // We need to seek to the trim size to skip over that part as we discard it - _, err := fr.file.Seek(TrimSize, io.SeekStart) - if err != nil { - return err - } - - // get the part of the file that we want to keep - keep, err := io.ReadAll(fr.file) - if err != nil { - return err - } - - all := []byte(TrimMsg) - all = append(all, keep...) - err = atomicfile.WriteFile(fr.file.Name(), all, 0o666) - if err != nil { - return err - } - - // re-open the handle as the file was renamed - err = fr.file.Close() - if err != nil { - return err - } - err = fr.open() - if err != nil { - return err - } - return nil -} - -// Write implements io.Writer for the log rotater -func (fr *FileRotater) Write(p []byte) (n int, err error) { - fr.mu.Lock() - defer fr.mu.Unlock() - fi, err := fr.file.Stat() - if err != nil { - return 0, err - } - - if fi.Size() >= MaxSize { - err = fr.trim() - if err != nil { - return 0, err - } - } - // we don't write atomically here as we want it to be as fast as possible - // and if we lose a part of one log statement it's not a big deal - return fr.file.Write(p) -} - -// Close closes the file in a safe way by locking and unlocking the mutex -func (fr *FileRotater) Close() error { - fr.mu.Lock() - defer fr.mu.Unlock() - return fr.file.Close() -} diff --git a/internal/log/rotate_test.go b/internal/log/rotate_test.go deleted file mode 100644 index 4fa77fd..0000000 --- a/internal/log/rotate_test.go +++ /dev/null @@ -1,144 +0,0 @@ -package log - -import ( - "io" - "os" - "path/filepath" - "strings" - "sync" - "testing" -) - -func createFileRotater(t *testing.T) (*FileRotater, func()) { - d, err := os.MkdirTemp("", "logtest") - if err != nil { - t.Fatalf("failed creating tmp dir: %v", err) - } - fn := filepath.Join(d, "test.log") - fr, err := NewFileRotater(fn) - if err != nil { - t.Fatalf("NewFileRotater error: %v", err) - } - if fr == nil { - t.Fatal("NewFileRotater returned nil") - } - return fr, func() { - err := os.RemoveAll(d) - if err != nil { - t.Errorf("failed removing file: %v", err) - } - } -} - -func TestNewFileRotater(t *testing.T) { - _, cleanup := createFileRotater(t) - cleanup() - - d, err := os.MkdirTemp("", "anotherlogtest") - if err != nil { - t.Fatalf("failed creating another tmp dir: %v", err) - } - nef := filepath.Join(d, "notexist", "test.log") - _, err = NewFileRotater(nef) - if err == nil { - t.Error("NewFileRotater returned no error with nonexistent dir") - } -} - -func TestWriteConcurrent(t *testing.T) { - fr, cleanup := createFileRotater(t) - defer cleanup() - MaxSize = 5 - var wg sync.WaitGroup - for range 5 { - wg.Add(1) - go func() { - _, err := fr.Write([]byte("test")) - defer wg.Done() - if err != nil { - t.Errorf("concurrent write returned an error: %v", err) - } - }() - } - wg.Wait() -} - -func TestWriteTrim(t *testing.T) { - fr, cleanup := createFileRotater(t) - defer cleanup() - writeNCheckSize := func(n int, size int64) { - buf := make([]byte, n) - - for i := range n { - buf[i] = 'x' - } - _, err := fr.Write(buf) - if err != nil { - t.Fatalf("failed writing: %v", err) - } - - fs, err := fr.file.Stat() - if err != nil { - t.Fatalf("failed getting size: %v", err) - } - - gsize := fs.Size() - - if gsize != size { - t.Fatalf("got: %v, want: %v, max size: %v", gsize, size, MaxSize) - } - } - - // we test by writing a start message and checking if it disappears after trimmign - // the max size we set to the length of the start message + 20 bytes - begS := "this is the start" - begB := []byte(begS) - startN := int64(len(begB)) - MaxSize = startN + 20 - TrimSize = MaxSize / 2 - - // no trim yet - _, err := fr.Write(begB) - if err != nil { - t.Fatalf("failed writing start message: %v", err) - } - - // write until the trimming size - writeNCheckSize(5, startN+5) - writeNCheckSize(15, MaxSize) - - // set the length we want to write - var n int64 = 11 - - // now the size should be the length of the trimmed message plus the remaining (non-trimmed part of the file) plus the length we want to write - - size := int64(len(TrimMsg)) + (MaxSize - TrimSize) + n - writeNCheckSize(11, size) - - // disable trimming by setting it to a high value - MaxSize = 9000 - TrimSize = 9000 - - // now the size should be the old size plus the write size - newN := 12 - writeNCheckSize(newN, size+int64(newN)) - - _, err = fr.file.Seek(0, io.SeekStart) - if err != nil { - t.Fatalf("failed going to beginning of file: %v", err) - } - - b, err := io.ReadAll(fr.file) - if err != nil { - t.Fatalf("failed reading file: %v", err) - } - - corpus := string(b) - if strings.Contains(corpus, begS) { - t.Fatalf("file still contains beginning message: %v", corpus) - } - - if !strings.Contains(corpus, TrimMsg) { - t.Fatalf("file does not contain trim message: %v", corpus) - } -} diff --git a/internal/loglevel/loglevel.go b/internal/loglevel/loglevel.go new file mode 100644 index 0000000..74dd49f --- /dev/null +++ b/internal/loglevel/loglevel.go @@ -0,0 +1,37 @@ +// Package loglevel implements a basic level based logger +package loglevel + +import ( + "fmt" + "io" + "log/slog" + "os" + "path" +) + +type Logger struct { + fr *FileRotater +} + +func (l *Logger) Init(dir string) (*slog.Logger, error) { + err := os.MkdirAll(dir, 0o700) + if err != nil { + return nil, err + } + name := path.Join(dir, "log") + + fr, err := NewFileRotater(name) + if err != nil { + return nil, fmt.Errorf("failed creating log rotater: %w", err) + } + l.fr = fr + multi := io.MultiWriter(os.Stdout, fr) + handler := slog.NewTextHandler(multi, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }) + return slog.New(handler), nil +} + +func (l *Logger) Close() error { + return l.fr.Close() +} diff --git a/internal/loglevel/rotate.go b/internal/loglevel/rotate.go new file mode 100644 index 0000000..bfd8351 --- /dev/null +++ b/internal/loglevel/rotate.go @@ -0,0 +1,109 @@ +package loglevel + +import ( + "io" + "os" + "sync" + + "codeberg.org/eduVPN/eduvpn-common/internal/atomicfile" +) + +var ( + // MaxSize is the maximum size in bytes from when to start trimming + MaxSize int64 = 10 * 1024 * 1024 + + // TrimSize denotes how much to trim from the beginning + TrimSize = MaxSize / 2 + + // TrimMsg is the message to display when it was trimmed + TrimMsg = "--- previous part was trimmed by eduvpn-common as the file was too big (10MB) ---\n" +) + +// FileRotater is a file that is trimmed when a maximum size is reached +// This is for logging useful +type FileRotater struct { + filename string + file *os.File + mu sync.Mutex +} + +// NewFileRotater creates a new log file rotater +func NewFileRotater(filename string) (*FileRotater, error) { + fr := &FileRotater{ + filename: filename, + } + + err := fr.open() + if err != nil { + return nil, err + } + return fr, nil +} + +func (fr *FileRotater) open() error { + f, err := os.OpenFile(fr.filename, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666) + if err != nil { + return err + } + fr.file = f + return nil +} + +func (fr *FileRotater) trim() error { + // We need to seek to the trim size to skip over that part as we discard it + _, err := fr.file.Seek(TrimSize, io.SeekStart) + if err != nil { + return err + } + + // get the part of the file that we want to keep + keep, err := io.ReadAll(fr.file) + if err != nil { + return err + } + + all := []byte(TrimMsg) + all = append(all, keep...) + err = atomicfile.WriteFile(fr.file.Name(), all, 0o666) + if err != nil { + return err + } + + // re-open the handle as the file was renamed + err = fr.file.Close() + if err != nil { + return err + } + err = fr.open() + if err != nil { + return err + } + return nil +} + +// Write implements io.Writer for the log rotater +func (fr *FileRotater) Write(p []byte) (n int, err error) { + fr.mu.Lock() + defer fr.mu.Unlock() + fi, err := fr.file.Stat() + if err != nil { + return 0, err + } + + if fi.Size() >= MaxSize { + err = fr.trim() + if err != nil { + return 0, err + } + } + // we don't write atomically here as we want it to be as fast as possible + // and if we lose a part of one log statement it's not a big deal + return fr.file.Write(p) +} + +// Close closes the file in a safe way by locking and unlocking the mutex +func (fr *FileRotater) Close() error { + fr.mu.Lock() + defer fr.mu.Unlock() + return fr.file.Close() +} diff --git a/internal/loglevel/rotate_test.go b/internal/loglevel/rotate_test.go new file mode 100644 index 0000000..d836330 --- /dev/null +++ b/internal/loglevel/rotate_test.go @@ -0,0 +1,144 @@ +package loglevel + +import ( + "io" + "os" + "path/filepath" + "strings" + "sync" + "testing" +) + +func createFileRotater(t *testing.T) (*FileRotater, func()) { + d, err := os.MkdirTemp("", "logtest") + if err != nil { + t.Fatalf("failed creating tmp dir: %v", err) + } + fn := filepath.Join(d, "test.log") + fr, err := NewFileRotater(fn) + if err != nil { + t.Fatalf("NewFileRotater error: %v", err) + } + if fr == nil { + t.Fatal("NewFileRotater returned nil") + } + return fr, func() { + err := os.RemoveAll(d) + if err != nil { + t.Errorf("failed removing file: %v", err) + } + } +} + +func TestNewFileRotater(t *testing.T) { + _, cleanup := createFileRotater(t) + cleanup() + + d, err := os.MkdirTemp("", "anotherlogtest") + if err != nil { + t.Fatalf("failed creating another tmp dir: %v", err) + } + nef := filepath.Join(d, "notexist", "test.log") + _, err = NewFileRotater(nef) + if err == nil { + t.Error("NewFileRotater returned no error with nonexistent dir") + } +} + +func TestWriteConcurrent(t *testing.T) { + fr, cleanup := createFileRotater(t) + defer cleanup() + MaxSize = 5 + var wg sync.WaitGroup + for range 5 { + wg.Add(1) + go func() { + _, err := fr.Write([]byte("test")) + defer wg.Done() + if err != nil { + t.Errorf("concurrent write returned an error: %v", err) + } + }() + } + wg.Wait() +} + +func TestWriteTrim(t *testing.T) { + fr, cleanup := createFileRotater(t) + defer cleanup() + writeNCheckSize := func(n int, size int64) { + buf := make([]byte, n) + + for i := range n { + buf[i] = 'x' + } + _, err := fr.Write(buf) + if err != nil { + t.Fatalf("failed writing: %v", err) + } + + fs, err := fr.file.Stat() + if err != nil { + t.Fatalf("failed getting size: %v", err) + } + + gsize := fs.Size() + + if gsize != size { + t.Fatalf("got: %v, want: %v, max size: %v", gsize, size, MaxSize) + } + } + + // we test by writing a start message and checking if it disappears after trimmign + // the max size we set to the length of the start message + 20 bytes + begS := "this is the start" + begB := []byte(begS) + startN := int64(len(begB)) + MaxSize = startN + 20 + TrimSize = MaxSize / 2 + + // no trim yet + _, err := fr.Write(begB) + if err != nil { + t.Fatalf("failed writing start message: %v", err) + } + + // write until the trimming size + writeNCheckSize(5, startN+5) + writeNCheckSize(15, MaxSize) + + // set the length we want to write + var n int64 = 11 + + // now the size should be the length of the trimmed message plus the remaining (non-trimmed part of the file) plus the length we want to write + + size := int64(len(TrimMsg)) + (MaxSize - TrimSize) + n + writeNCheckSize(11, size) + + // disable trimming by setting it to a high value + MaxSize = 9000 + TrimSize = 9000 + + // now the size should be the old size plus the write size + newN := 12 + writeNCheckSize(newN, size+int64(newN)) + + _, err = fr.file.Seek(0, io.SeekStart) + if err != nil { + t.Fatalf("failed going to beginning of file: %v", err) + } + + b, err := io.ReadAll(fr.file) + if err != nil { + t.Fatalf("failed reading file: %v", err) + } + + corpus := string(b) + if strings.Contains(corpus, begS) { + t.Fatalf("file still contains beginning message: %v", corpus) + } + + if !strings.Contains(corpus, TrimMsg) { + t.Fatalf("file does not contain trim message: %v", corpus) + } +} diff --git a/internal/server/custom.go b/internal/server/custom.go index a9a26b9..0837c86 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -5,8 +5,8 @@ import ( "log/slog" "time" - "codeberg.org/eduVPN/eduvpn-common/internal/api" "codeberg.org/eduVPN/eduvpn-common/internal/config/v2" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi" "codeberg.org/eduVPN/eduvpn-common/types/server" "codeberg.org/jwijenbergh/eduoauth-go/v2" ) @@ -16,7 +16,7 @@ import ( // `id` is the identifier of the server, the base URL // `ot` specifies specifies the start time OAuth was already triggered func (s *Servers) AddCustom(ctx context.Context, id string, ot *int64) error { - sd := api.ServerData{ + sd := eduvpnapi.ServerData{ ID: id, Type: server.TypeCustom, BaseWK: id, @@ -40,7 +40,7 @@ func (s *Servers) AddCustom(ctx context.Context, id string, ot *int64) error { } // Authorize by creating the API object - _, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil) + _, err = eduvpnapi.NewAPI(ctx, s.clientID, sd, s.cb, nil) if err != nil { // authorization has failed, remove the server again rerr := s.config.RemoveServer(id, server.TypeCustom) @@ -58,7 +58,7 @@ func (s *Servers) AddCustom(ctx context.Context, id string, ot *int64) error { // `tok` are the tokens such that we can initialize the API // `disableAuth` is set to True when authorization should not be triggered func (s *Servers) GetCustom(ctx context.Context, id string, tok *eduoauth.Token, disableAuth bool) (*Server, error) { - sd := api.ServerData{ + sd := eduvpnapi.ServerData{ ID: id, Type: server.TypeCustom, BaseWK: id, @@ -71,7 +71,7 @@ func (s *Servers) GetCustom(ctx context.Context, id string, tok *eduoauth.Token, if err != nil { return nil, err } - a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok) + a, err := eduvpnapi.NewAPI(ctx, s.clientID, sd, s.cb, tok) if err != nil { return nil, err } diff --git a/internal/server/institute.go b/internal/server/institute.go index c357a4d..9280d1e 100644 --- a/internal/server/institute.go +++ b/internal/server/institute.go @@ -5,9 +5,9 @@ import ( "log/slog" "time" - "codeberg.org/eduVPN/eduvpn-common/internal/api" "codeberg.org/eduVPN/eduvpn-common/internal/config/v2" "codeberg.org/eduVPN/eduvpn-common/internal/discovery" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi" "codeberg.org/eduVPN/eduvpn-common/types/server" "codeberg.org/jwijenbergh/eduoauth-go/v2" ) @@ -24,7 +24,7 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, return err } - sd := api.ServerData{ + sd := eduvpnapi.ServerData{ ID: dsrv.BaseURL, Type: server.TypeInstituteAccess, BaseWK: dsrv.BaseURL, @@ -49,7 +49,7 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, } // Authorize by creating the API object - _, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil) + _, err = eduvpnapi.NewAPI(ctx, s.clientID, sd, s.cb, nil) if err != nil { // authorization has failed, remove the server again rerr := s.config.RemoveServer(dsrv.BaseURL, server.TypeInstituteAccess) @@ -79,7 +79,7 @@ func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery. if err != nil { return nil, err } - sd := api.ServerData{ + sd := eduvpnapi.ServerData{ ID: dsrv.BaseURL, Type: server.TypeInstituteAccess, BaseWK: dsrv.BaseURL, @@ -87,7 +87,7 @@ func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery. DisableAuthorize: disableAuth, } // Authorize by creating the API object - a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok) + a, err := eduvpnapi.NewAPI(ctx, s.clientID, sd, s.cb, tok) if err != nil { return nil, err } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index d25750f..e97efbd 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -8,9 +8,9 @@ import ( "strings" "time" - "codeberg.org/eduVPN/eduvpn-common/internal/api" "codeberg.org/eduVPN/eduvpn-common/internal/config/v2" "codeberg.org/eduVPN/eduvpn-common/internal/discovery" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi" "codeberg.org/eduVPN/eduvpn-common/types/server" "codeberg.org/jwijenbergh/eduoauth-go/v2" ) @@ -54,7 +54,7 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org return err } - sd := api.ServerData{ + sd := eduvpnapi.ServerData{ ID: dorg.OrgID, Type: server.TypeSecureInternet, BaseWK: dsrv.BaseURL, @@ -92,7 +92,7 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org } // Authorize by creating the API object - _, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil) + _, err = eduvpnapi.NewAPI(ctx, s.clientID, sd, s.cb, nil) if err != nil { // authorization has failed, remove the server again rerr := s.config.RemoveServer(orgID, server.TypeSecureInternet) @@ -126,7 +126,7 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery. return nil, err } - sd := api.ServerData{ + sd := eduvpnapi.ServerData{ ID: dorg.OrgID, Type: server.TypeSecureInternet, BaseWK: dloc.BaseURL, @@ -148,7 +148,7 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery. DisableAuthorize: disableAuth, } - a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok) + a, err := eduvpnapi.NewAPI(ctx, s.clientID, sd, s.cb, tok) if err != nil { return nil, err } diff --git a/internal/server/server.go b/internal/server/server.go index 2372494..4623224 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -7,9 +7,9 @@ import ( "os" "time" - "codeberg.org/eduVPN/eduvpn-common/internal/api" - "codeberg.org/eduVPN/eduvpn-common/internal/api/profiles" v2 "codeberg.org/eduVPN/eduvpn-common/internal/config/v2" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" "codeberg.org/eduVPN/eduvpn-common/types/protocol" srvtypes "codeberg.org/eduVPN/eduvpn-common/types/server" ) @@ -18,7 +18,7 @@ import ( type Server struct { identifier string t srvtypes.Type - apiw *api.API + apiw *eduvpnapi.API storage *v2.V2 } @@ -26,7 +26,7 @@ type Server struct { var ErrInvalidProfile = errors.New("invalid profile") // NewServer creates a new server -func (s *Servers) NewServer(identifier string, t srvtypes.Type, api *api.API) Server { +func (s *Servers) NewServer(identifier string, t srvtypes.Type, api *eduvpnapi.API) Server { return Server{ identifier: identifier, t: t, @@ -65,7 +65,7 @@ func (s *Server) FreshProfiles(ctx context.Context) (*profiles.Info, error) { return prfs, nil } -func (s *Server) api() (*api.API, error) { +func (s *Server) api() (*eduvpnapi.API, error) { if s.apiw == nil { return nil, errors.New("no API object found") } diff --git a/internal/server/servers.go b/internal/server/servers.go index bd22ffd..43716a4 100644 --- a/internal/server/servers.go +++ b/internal/server/servers.go @@ -5,9 +5,9 @@ import ( "errors" "fmt" - "codeberg.org/eduVPN/eduvpn-common/internal/api" "codeberg.org/eduVPN/eduvpn-common/internal/config/v2" "codeberg.org/eduVPN/eduvpn-common/internal/discovery" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi" srvtypes "codeberg.org/eduVPN/eduvpn-common/types/server" "codeberg.org/jwijenbergh/eduoauth-go/v2" ) @@ -15,7 +15,7 @@ import ( // Callbacks defines the interface for doing certain callback operations type Callbacks interface { // api.Callbacks is the API callback interface - api.Callbacks + eduvpnapi.Callbacks // GettingConfig is called when the config is obtained GettingConfig() error // InvalidProfile is called when an invalid profile is found diff --git a/internal/test/server.go b/internal/test/server.go index 2596f29..82dd2f7 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -8,7 +8,7 @@ import ( "net/http" "net/http/httptest" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" ) // Server wraps a HTTP test server @@ -71,7 +71,7 @@ func NewServerWithHandles(hps []HandlerPath, listener net.Listener) *Server { } // Client returns a test client that trusts the HTTPS certificates -func (srv *Server) Client() (*httpw.Client, error) { +func (srv *Server) Client() (*httpwrap.Client, error) { // Get the certs from the test server certs := x509.NewCertPool() for _, c := range srv.TLS.Certificates { @@ -91,6 +91,6 @@ func (srv *Server) Client() (*httpw.Client, error) { }, } // Override the client such that it only trusts the test server cert - httpC := httpw.NewClient(client) + httpC := httpwrap.NewClient(client) return httpC, nil } diff --git a/internal/version/version.go b/internal/version/version.go deleted file mode 100644 index 015a57b..0000000 --- a/internal/version/version.go +++ /dev/null @@ -1,9 +0,0 @@ -// Package version defines a version string that is used for: -// - building -// - the user agent -// - tagging -package version - -// Version is the latest version -// Update this when releasing -const Version = "4.0.0" diff --git a/make_release.sh b/make_release.sh index 02affb0..bbb633d 100755 --- a/make_release.sh +++ b/make_release.sh @@ -2,7 +2,7 @@ # This script was adapted from fkooman: https://git.sr.ht/~fkooman/vpn-daemon/tree/main/item/make_release.sh. Thanks! # -# Make a release of the version specified in internal/version/version.go and automatically release the artifacts +# Make a release of the version specified in internal/commonver/commonver.go and automatically release the artifacts # # Fail if error @@ -11,7 +11,7 @@ set -e echo "building $(git log -n 1 | head -n 1)" BRANCH="main" PROJECT_NAME=$(basename "${PWD}") -PROJECT_VERSION=$(grep -o 'const Version = "[^"]*' internal/version/version.go | cut -d '"' -f 2) +PROJECT_VERSION=$(grep -o 'const Version = "[^"]*' internal/commonver/commonver.go | cut -d '"' -f 2) PRERELEASE=false while [[ "$#" -gt 0 ]]; do diff --git a/prepare_release.sh b/prepare_release.sh index 61d0cfc..91ab23a 100755 --- a/prepare_release.sh +++ b/prepare_release.sh @@ -45,8 +45,8 @@ if [[ $(git diff) ]]; then fi # Replace version number -# replace in internal/version -sed -i "s/const Version = \".*\"/const Version = \"${PROJECT_VERSION}\"/" internal/version/version.go +# replace in internal/commonver +sed -i "s/const Version = \".*\"/const Version = \"${PROJECT_VERSION}\"/" internal/commonver/commonver.go sed -i "s/version = .*/version = ${PROJECT_VERSION}/" wrappers/python/setup.cfg sed -i "s/COMMON_VERSION = \".*\"/COMMON_VERSION = \"${PROJECT_VERSION}\"/" wrappers/python/setup.py sed -i "s/__version__ = \".*\"/__version__ = \"${PROJECT_VERSION}\"/" wrappers/python/eduvpn_common/__init__.py diff --git a/proxy/proxy.go b/proxy/proxy.go index af0cd89..b8a9ad9 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -12,7 +12,7 @@ import ( "codeberg.org/eduVPN/proxyguard" "codeberg.org/eduVPN/eduvpn-common/i18n/err" - httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" ) // Logger is defined here such that we can update the proxyguard logger @@ -43,7 +43,7 @@ func NewProxyguard(ctx context.Context, lp int, tcpsp int, peer string, setupSoc ListenPort: lp, TCPSourcePort: tcpsp, SetupSocket: setupSocket, - UserAgent: httpw.UserAgent, + UserAgent: httpwrap.UserAgent, }, resChan: make(chan struct{}), } diff --git a/upload_release.sh b/upload_release.sh index a16d78d..b2537b9 100755 --- a/upload_release.sh +++ b/upload_release.sh @@ -14,7 +14,7 @@ fi ORG=eduVPN API_KEY=$(cat "$API_KEY_FILE") PROJECT_NAME=$(basename "$(pwd)") -PROJECT_VERSION=$(grep -o 'const Version = "[^"]*' internal/version/version.go | cut -d '"' -f 2) +PROJECT_VERSION=$(grep -o 'const Version = "[^"]*' internal/commonver/commonver.go | cut -d '"' -f 2) PRERELEASE=false while [[ "$#" -gt 0 ]]; do diff --git a/wrappers/python/Makefile b/wrappers/python/Makefile index c55faf8..4e7f8d7 100644 --- a/wrappers/python/Makefile +++ b/wrappers/python/Makefile @@ -1,7 +1,7 @@ .DEFAULT_GOAL := pack .PHONY: install-lib pack test clean lint fmt -VERSION := $(shell grep -o 'const Version = "[^"]*' ../../internal/version/version.go | cut -d '"' -f 2) +VERSION := $(shell grep -o 'const Version = "[^"]*' ../../internal/commonver/commonver.go | cut -d '"' -f 2) install-lib: rm -rf eduvpn_common/lib/* -- cgit v1.2.3