summaryrefslogtreecommitdiff
path: root/internal/eduvpnapi/eduvpnapi.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/eduvpnapi/eduvpnapi.go')
-rw-r--r--internal/eduvpnapi/eduvpnapi.go395
1 files changed, 395 insertions, 0 deletions
diff --git a/internal/eduvpnapi/eduvpnapi.go b/internal/eduvpnapi/eduvpnapi.go
new file mode 100644
index 0000000..62fe0d1
--- /dev/null
+++ b/internal/eduvpnapi/eduvpnapi.go
@@ -0,0 +1,395 @@
+// Package eduvpnapi implements version 3 of the eduVPN api: https://docs.eduvpn.org/server/v3/api.html
+package eduvpnapi
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "time"
+
+ "codeberg.org/jwijenbergh/eduoauth-go/v2"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints"
+ "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles"
+ "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap"
+ "codeberg.org/eduVPN/eduvpn-common/internal/wireguard"
+ "codeberg.org/eduVPN/eduvpn-common/types/protocol"
+ "codeberg.org/eduVPN/eduvpn-common/types/server"
+)
+
+// Callbacks is the API callback interface
+// It is used to trigger authorization and forward token updates
+type Callbacks interface {
+ // TriggerAuth is called when authorization should be triggered
+ TriggerAuth(context.Context, string, bool) (string, error)
+ // AuthDone is called when authorization has just completed
+ AuthDone(string, server.Type)
+ // TokensUpdates is called when tokens are updated
+ TokensUpdated(string, server.Type, eduoauth.Token)
+}
+
+// ServerData is the data for a server that is passed to the API struct
+type ServerData struct {
+ // ID is the identifier for the server
+ ID string
+ // Type is the type of server
+ Type server.Type
+ // BaseWK is the base well-known endpoint
+ BaseWK string
+ // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet
+ BaseAuthWK string
+ // ProcessAuth processes the OAuth authorization
+ ProcessAuth func(context.Context, string) (string, error)
+ // DisableAuthorize indicates whether or not new authorization requests should be disabled
+ DisableAuthorize bool
+ // transport is the HTTP transport, only used for testing currently
+ transport http.RoundTripper
+}
+
+// Transport returns the transport to be used for the server
+// By default it uses the transport from internal/httpwrap DefaultTransport
+func (s *ServerData) Transport() http.RoundTripper {
+ if s.transport == nil {
+ return httpwrap.DefaultTransport
+ }
+ return s.transport
+}
+
+// API is the top-level struct that each method is defined on
+type API struct {
+ cb Callbacks
+ // oauth is the oauth object
+ oauth *eduoauth.OAuth
+ // Data is the server data
+ Data ServerData
+}
+
+// NewAPI creates a new API object by creating an OAuth object
+func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) {
+ cr := customRedirect(clientID)
+ // Construct OAuth
+
+ transp := sd.Transport()
+ post := true
+ // we do not support non-loopback clients with response_mode form_post
+ if cr != "" {
+ post = false
+ }
+ o := eduoauth.OAuth{
+ ClientID: clientID,
+ EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) {
+ ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp)
+ if err != nil {
+ return nil, err
+ }
+ return &eduoauth.EndpointResponse{
+ AuthorizationURL: ep.API.V3.Authorization,
+ TokenURL: ep.API.V3.Token,
+ }, nil
+ },
+ CustomRedirect: cr,
+ FormPost: post,
+ RedirectPath: "/callback",
+ TokensUpdated: func(tok eduoauth.Token) {
+ cb.TokensUpdated(sd.ID, sd.Type, tok)
+ },
+ Transport: transp,
+ UserAgent: httpwrap.UserAgent,
+ }
+
+ if tokens != nil {
+ o.UpdateTokens(*tokens)
+ }
+
+ api := &API{
+ cb: cb,
+ oauth: &o,
+ Data: sd,
+ }
+ err := api.authorize(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return api, nil
+}
+
+// ErrAuthorizeDisabled is returned when authorization is disabled but is needed to complete
+var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled")
+
+func (a *API) authorize(ctx context.Context) (err error) {
+ _, err = a.oauth.AccessToken(ctx)
+ // already authorized
+ if err == nil {
+ return nil
+ }
+
+ // otherwise check if invalid tokens,
+ // if not then something else is wrong with the API
+ // return an error
+ tErr := &eduoauth.TokensInvalidError{}
+ if !errors.As(err, &tErr) {
+ return err
+ }
+
+ if a.Data.DisableAuthorize {
+ return ErrAuthorizeDisabled
+ }
+
+ defer func() {
+ if err == nil {
+ a.cb.AuthDone(a.Data.ID, a.Data.Type)
+ }
+ }()
+
+ scope := "config"
+ url, err := a.oauth.AuthURL(ctx, scope)
+ if err != nil {
+ return err
+ }
+ if a.Data.ProcessAuth != nil {
+ url, err = a.Data.ProcessAuth(ctx, url)
+ if err != nil {
+ return err
+ }
+ }
+ // We expect an uri if custom redirect is non empty
+ uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "")
+ if err != nil {
+ return err
+ }
+ // The uri is only given here if a custom redirect is done
+ err = a.oauth.Exchange(ctx, uri)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) {
+ ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport())
+ if err != nil {
+ return nil, nil, err
+ }
+ u := ep.API.V3.API + endpoint
+
+ // TODO: Cache HTTP client?
+ httpC := httpwrap.NewClient(a.oauth.NewHTTPClient())
+ return httpC.Do(ctx, method, u, opts)
+}
+
+func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) {
+ h, body, err := a.authorized(ctx, method, endpoint, opts)
+ if err == nil {
+ return h, body, nil
+ }
+
+ statErr := &httpwrap.StatusError{}
+ // Only retry authorized if we get an HTTP 401
+ // TODO: Can the OAuth client handle this instead?
+ if errors.As(err, &statErr) && statErr.Status == 401 {
+ slog.Debug("Got a HTTP 401. Marking tokens as expired...", "HTTP method", method, "endpoint", endpoint)
+ // Mark the token as expired and retry, so we trigger the refresh flow
+ a.oauth.SetTokenExpired()
+ h, body, err = a.authorized(ctx, method, endpoint, opts)
+ }
+ // Tokens is invalid we need to renew and authorize again
+ tErr := &eduoauth.TokensInvalidError{}
+ if err != nil && errors.As(err, &tErr) {
+ // Mark the token as invalid and retry, so we trigger the authorization flow
+ a.oauth.SetTokenRenew()
+ slog.Debug("The tokens were invalid, trying again...")
+ if autherr := a.authorize(ctx); autherr != nil {
+ return nil, nil, autherr
+ }
+ return a.authorized(ctx, method, endpoint, opts)
+ }
+ return h, body, err
+}
+
+// Disconnect disconnects a client from the server by sending a /disconnect API call
+// This cleans up resources such as WireGuard IP allocation
+func (a *API) Disconnect(ctx context.Context) error {
+ _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpwrap.OptionalParams{Timeout: 5 * time.Second})
+ return err
+}
+
+// Info does the /info API call
+func (a *API) Info(ctx context.Context) (*profiles.Info, error) {
+ _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed API /info: %w", err)
+ }
+ p := profiles.Info{}
+ if err = json.Unmarshal(body, &p); err != nil {
+ return nil, fmt.Errorf("failed API /info: %w", err)
+ }
+ return &p, nil
+}
+
+// ConnectData is the data that is returned when the /connect call completes without error
+type ConnectData struct {
+ // Configuration is the VPN configuration
+ Configuration string
+ // Protocol tells us what protocol it is, OpenVPN or WireGuard (proxied or not)
+ Protocol protocol.Protocol
+ // Expires tells us when this configuration expires
+ Expires time.Time
+}
+
+// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
+func boolToYesNo(preferTCP bool) string {
+ if preferTCP {
+ return "yes"
+ }
+ return "no"
+}
+
+func protocolFromCT(ct string) (protocol.Protocol, error) {
+ switch ct {
+ case "application/x-wireguard-profile":
+ return protocol.WireGuard, nil
+ case "application/x-wireguard+tcp-profile":
+ return protocol.WireGuardProxy, nil
+ case "application/x-openvpn-profile":
+ return protocol.OpenVPN, nil
+ }
+ return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct)
+}
+
+// ErrNoProtocols is returned when a connect call is given with an empty protocol slice
+var ErrNoProtocols = errors.New("no protocols supplied")
+
+// ErrUnknownProtocol is returned when the client in a connect gives an unknown protocol
+var ErrUnknownProtocol = errors.New("unknown protocol supplied")
+
+// Connect sends a /connect to an eduVPN server
+// `ctx` is the context used for cancellation
+// protos is the list of protocols supported and wanted by the client
+func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) {
+ hdrs := http.Header{
+ "content-type": {"application/x-www-form-urlencoded"},
+ }
+ uv := url.Values{
+ "profile_id": {prof.ID},
+ }
+
+ if len(protos) == 0 {
+ return nil, ErrNoProtocols
+ }
+
+ var wgKey *wgtypes.Key
+
+ // Loop over the protocols and set the correct headers and values
+ for _, p := range protos {
+ switch p {
+ case protocol.WireGuard:
+ gk, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ wgKey = &gk
+ // Set the public key
+ pubkey := wgKey.PublicKey()
+ uv.Set("public_key", pubkey.String())
+ hdrs.Add("accept", "application/x-wireguard-profile")
+ hdrs.Add("accept", "application/x-wireguard+tcp-profile")
+ case protocol.OpenVPN:
+ hdrs.Add("accept", "application/x-openvpn-profile")
+ default:
+ return nil, ErrUnknownProtocol
+ }
+ }
+ // set prefer TCP
+ uv.Set("prefer_tcp", boolToYesNo(pTCP))
+
+ // Construct the parameters
+ params := &httpwrap.OptionalParams{Headers: hdrs, Body: uv}
+ h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params)
+ if err != nil {
+ return nil, fmt.Errorf("failed API /connect call: %w", err)
+ }
+
+ // Parse expiry
+ expH := h.Get("expires")
+ if expH == "" {
+ return nil, errors.New("the server did not give an expires header")
+ }
+ expT, err := http.ParseTime(expH)
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing expiry time: %w", err)
+ }
+
+ vpnCfg := string(body)
+ // Parse content type
+ contentH := h.Get("content-type")
+ proto, err := protocolFromCT(contentH)
+ if err != nil {
+ return nil, err
+ }
+
+ if proto == protocol.OpenVPN {
+ // ensure scripts are not ran by default by append script-security 0 to the config
+ vpnCfg += "\nscript-security 0"
+ return &ConnectData{
+ Configuration: vpnCfg,
+ Protocol: proto,
+ Expires: expT,
+ }, nil
+ }
+
+ vpnCfg, err = wireguard.Config(vpnCfg, wgKey)
+ if err != nil {
+ return nil, err
+ }
+ return &ConnectData{
+ Configuration: vpnCfg,
+ Protocol: proto,
+ Expires: expT,
+ }, nil
+}
+
+func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) {
+ uStr, err := httpwrap.JoinURLPath(url, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return nil, err
+ }
+ httpC := httpwrap.NewClient(nil)
+ httpC.Client.Transport = tp
+ _, body, err := httpC.Get(ctx, uStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
+ }
+
+ ep := endpoints.Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
+ return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
+ }
+ err = ep.Validate()
+ if err != nil {
+ return nil, err
+ }
+ return &ep, nil
+}
+
+// OAuthLogger is defined here to update the internal logger
+// for the eduoauth library
+type OAuthLogger struct{}
+
+// Logf logs a message with parameters
+func (ol *OAuthLogger) Logf(msg string, params ...any) {
+ slog.Debug("OAuth log", "log", fmt.Sprintf(msg, params...))
+}
+
+// Log logs a message
+func (ol *OAuthLogger) Log(msg string) {
+ slog.Debug("OAuth log", "log", msg)
+}
+
+func init() {
+ eduoauth.UpdateLogger(&OAuthLogger{})
+}