diff options
Diffstat (limited to 'internal/eduvpnapi/eduvpnapi.go')
| -rw-r--r-- | internal/eduvpnapi/eduvpnapi.go | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/internal/eduvpnapi/eduvpnapi.go b/internal/eduvpnapi/eduvpnapi.go new file mode 100644 index 0000000..62fe0d1 --- /dev/null +++ b/internal/eduvpnapi/eduvpnapi.go @@ -0,0 +1,395 @@ +// Package eduvpnapi implements version 3 of the eduVPN api: https://docs.eduvpn.org/server/v3/api.html +package eduvpnapi + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "time" + + "codeberg.org/jwijenbergh/eduoauth-go/v2" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/endpoints" + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/wireguard" + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" +) + +// Callbacks is the API callback interface +// It is used to trigger authorization and forward token updates +type Callbacks interface { + // TriggerAuth is called when authorization should be triggered + TriggerAuth(context.Context, string, bool) (string, error) + // AuthDone is called when authorization has just completed + AuthDone(string, server.Type) + // TokensUpdates is called when tokens are updated + TokensUpdated(string, server.Type, eduoauth.Token) +} + +// ServerData is the data for a server that is passed to the API struct +type ServerData struct { + // ID is the identifier for the server + ID string + // Type is the type of server + Type server.Type + // BaseWK is the base well-known endpoint + BaseWK string + // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet + BaseAuthWK string + // ProcessAuth processes the OAuth authorization + ProcessAuth func(context.Context, string) (string, error) + // DisableAuthorize indicates whether or not new authorization requests should be disabled + DisableAuthorize bool + // transport is the HTTP transport, only used for testing currently + transport http.RoundTripper +} + +// Transport returns the transport to be used for the server +// By default it uses the transport from internal/httpwrap DefaultTransport +func (s *ServerData) Transport() http.RoundTripper { + if s.transport == nil { + return httpwrap.DefaultTransport + } + return s.transport +} + +// API is the top-level struct that each method is defined on +type API struct { + cb Callbacks + // oauth is the oauth object + oauth *eduoauth.OAuth + // Data is the server data + Data ServerData +} + +// NewAPI creates a new API object by creating an OAuth object +func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) { + cr := customRedirect(clientID) + // Construct OAuth + + transp := sd.Transport() + post := true + // we do not support non-loopback clients with response_mode form_post + if cr != "" { + post = false + } + o := eduoauth.OAuth{ + ClientID: clientID, + EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK, transp) + if err != nil { + return nil, err + } + return &eduoauth.EndpointResponse{ + AuthorizationURL: ep.API.V3.Authorization, + TokenURL: ep.API.V3.Token, + }, nil + }, + CustomRedirect: cr, + FormPost: post, + RedirectPath: "/callback", + TokensUpdated: func(tok eduoauth.Token) { + cb.TokensUpdated(sd.ID, sd.Type, tok) + }, + Transport: transp, + UserAgent: httpwrap.UserAgent, + } + + if tokens != nil { + o.UpdateTokens(*tokens) + } + + api := &API{ + cb: cb, + oauth: &o, + Data: sd, + } + err := api.authorize(ctx) + if err != nil { + return nil, err + } + return api, nil +} + +// ErrAuthorizeDisabled is returned when authorization is disabled but is needed to complete +var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled") + +func (a *API) authorize(ctx context.Context) (err error) { + _, err = a.oauth.AccessToken(ctx) + // already authorized + if err == nil { + return nil + } + + // otherwise check if invalid tokens, + // if not then something else is wrong with the API + // return an error + tErr := &eduoauth.TokensInvalidError{} + if !errors.As(err, &tErr) { + return err + } + + if a.Data.DisableAuthorize { + return ErrAuthorizeDisabled + } + + defer func() { + if err == nil { + a.cb.AuthDone(a.Data.ID, a.Data.Type) + } + }() + + scope := "config" + url, err := a.oauth.AuthURL(ctx, scope) + if err != nil { + return err + } + if a.Data.ProcessAuth != nil { + url, err = a.Data.ProcessAuth(ctx, url) + if err != nil { + return err + } + } + // We expect an uri if custom redirect is non empty + uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") + if err != nil { + return err + } + // The uri is only given here if a custom redirect is done + err = a.oauth.Exchange(ctx, uri) + if err != nil { + return err + } + return nil +} + +func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) { + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK, a.Data.Transport()) + if err != nil { + return nil, nil, err + } + u := ep.API.V3.API + endpoint + + // TODO: Cache HTTP client? + httpC := httpwrap.NewClient(a.oauth.NewHTTPClient()) + return httpC.Do(ctx, method, u, opts) +} + +func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpwrap.OptionalParams) (http.Header, []byte, error) { + h, body, err := a.authorized(ctx, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpwrap.StatusError{} + // Only retry authorized if we get an HTTP 401 + // TODO: Can the OAuth client handle this instead? + if errors.As(err, &statErr) && statErr.Status == 401 { + slog.Debug("Got a HTTP 401. Marking tokens as expired...", "HTTP method", method, "endpoint", endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + a.oauth.SetTokenExpired() + h, body, err = a.authorized(ctx, method, endpoint, opts) + } + // Tokens is invalid we need to renew and authorize again + tErr := &eduoauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + // Mark the token as invalid and retry, so we trigger the authorization flow + a.oauth.SetTokenRenew() + slog.Debug("The tokens were invalid, trying again...") + if autherr := a.authorize(ctx); autherr != nil { + return nil, nil, autherr + } + return a.authorized(ctx, method, endpoint, opts) + } + return h, body, err +} + +// Disconnect disconnects a client from the server by sending a /disconnect API call +// This cleans up resources such as WireGuard IP allocation +func (a *API) Disconnect(ctx context.Context) error { + _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpwrap.OptionalParams{Timeout: 5 * time.Second}) + return err +} + +// Info does the /info API call +func (a *API) Info(ctx context.Context) (*profiles.Info, error) { + _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) + if err != nil { + return nil, fmt.Errorf("failed API /info: %w", err) + } + p := profiles.Info{} + if err = json.Unmarshal(body, &p); err != nil { + return nil, fmt.Errorf("failed API /info: %w", err) + } + return &p, nil +} + +// ConnectData is the data that is returned when the /connect call completes without error +type ConnectData struct { + // Configuration is the VPN configuration + Configuration string + // Protocol tells us what protocol it is, OpenVPN or WireGuard (proxied or not) + Protocol protocol.Protocol + // Expires tells us when this configuration expires + Expires time.Time +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func protocolFromCT(ct string) (protocol.Protocol, error) { + switch ct { + case "application/x-wireguard-profile": + return protocol.WireGuard, nil + case "application/x-wireguard+tcp-profile": + return protocol.WireGuardProxy, nil + case "application/x-openvpn-profile": + return protocol.OpenVPN, nil + } + return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct) +} + +// ErrNoProtocols is returned when a connect call is given with an empty protocol slice +var ErrNoProtocols = errors.New("no protocols supplied") + +// ErrUnknownProtocol is returned when the client in a connect gives an unknown protocol +var ErrUnknownProtocol = errors.New("unknown protocol supplied") + +// Connect sends a /connect to an eduVPN server +// `ctx` is the context used for cancellation +// protos is the list of protocols supported and wanted by the client +func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + uv := url.Values{ + "profile_id": {prof.ID}, + } + + if len(protos) == 0 { + return nil, ErrNoProtocols + } + + var wgKey *wgtypes.Key + + // Loop over the protocols and set the correct headers and values + for _, p := range protos { + switch p { + case protocol.WireGuard: + gk, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + wgKey = &gk + // Set the public key + pubkey := wgKey.PublicKey() + uv.Set("public_key", pubkey.String()) + hdrs.Add("accept", "application/x-wireguard-profile") + hdrs.Add("accept", "application/x-wireguard+tcp-profile") + case protocol.OpenVPN: + hdrs.Add("accept", "application/x-openvpn-profile") + default: + return nil, ErrUnknownProtocol + } + } + // set prefer TCP + uv.Set("prefer_tcp", boolToYesNo(pTCP)) + + // Construct the parameters + params := &httpwrap.OptionalParams{Headers: hdrs, Body: uv} + h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) + if err != nil { + return nil, fmt.Errorf("failed API /connect call: %w", err) + } + + // Parse expiry + expH := h.Get("expires") + if expH == "" { + return nil, errors.New("the server did not give an expires header") + } + expT, err := http.ParseTime(expH) + if err != nil { + return nil, fmt.Errorf("failed parsing expiry time: %w", err) + } + + vpnCfg := string(body) + // Parse content type + contentH := h.Get("content-type") + proto, err := protocolFromCT(contentH) + if err != nil { + return nil, err + } + + if proto == protocol.OpenVPN { + // ensure scripts are not ran by default by append script-security 0 to the config + vpnCfg += "\nscript-security 0" + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil + } + + vpnCfg, err = wireguard.Config(vpnCfg, wgKey) + if err != nil { + return nil, err + } + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil +} + +func getEndpoints(ctx context.Context, url string, tp http.RoundTripper) (*endpoints.Endpoints, error) { + uStr, err := httpwrap.JoinURLPath(url, "/.well-known/vpn-user-portal") + if err != nil { + return nil, err + } + httpC := httpwrap.NewClient(nil) + httpC.Client.Transport = tp + _, body, err := httpC.Get(ctx, uStr) + if err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + err = ep.Validate() + if err != nil { + return nil, err + } + return &ep, nil +} + +// OAuthLogger is defined here to update the internal logger +// for the eduoauth library +type OAuthLogger struct{} + +// Logf logs a message with parameters +func (ol *OAuthLogger) Logf(msg string, params ...any) { + slog.Debug("OAuth log", "log", fmt.Sprintf(msg, params...)) +} + +// Log logs a message +func (ol *OAuthLogger) Log(msg string) { + slog.Debug("OAuth log", "log", msg) +} + +func init() { + eduoauth.UpdateLogger(&OAuthLogger{}) +} |
