diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/api.go | 362 | ||||
| -rw-r--r-- | internal/api/endpoints/endpoints.go | 51 | ||||
| -rw-r--r-- | internal/api/profiles/profiles.go | 82 | ||||
| -rw-r--r-- | internal/api/redirect.go | 23 |
4 files changed, 518 insertions, 0 deletions
diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000..c9f0408 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,362 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/jwijenbergh/eduoauth-go" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/eduvpn/eduvpn-common/internal/api/endpoints" + "github.com/eduvpn/eduvpn-common/internal/api/profiles" + httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Callbacks interface { + TriggerAuth(context.Context, string, bool) (string, error) + AuthDone(string, server.Type) + TokensUpdated(string, server.Type, eduoauth.Token) +} + +type ServerData struct { + // ID is the identifier for the server + ID string + // Type is the type of server + Type server.Type + // BaseWK is the base well-known endpoint + BaseWK string + // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet + BaseAuthWK string + // ProcessAuth processes the OAuth authorization + ProcessAuth func(string) string + // SetAuthorizeTime sets the authorization time + SetAuthorizeTime func(time.Time) + // DisableAuthorize indicates whether or not new authorization requests should be disabled + DisableAuthorize bool +} + +type API struct { + cb Callbacks + // oauth is the oauth object + oauth *eduoauth.OAuth + // apiURL is the API url to send a request to + apiURL string + // Data is the server data + Data ServerData +} + +// NewAPI creates a new API object by creating an OAuth object +func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) { + ep, epauth, err := refreshEndpoints(ctx, sd) + if err != nil { + return nil, err + } + + cr := customRedirect(clientID) + // Construct OAuth + o := eduoauth.OAuth{ + ClientID: clientID, + BaseAuthorizationURL: epauth.Authorization, + TokenURL: epauth.Token, + CustomRedirect: cr, + RedirectPath: "/callback", + TokensUpdated: func(tok eduoauth.Token) { + cb.TokensUpdated(sd.ID, sd.Type, tok) + }, + } + + if tokens != nil { + o.UpdateTokens(*tokens) + } + + api := &API{ + cb: cb, + oauth: &o, + apiURL: ep.API, + Data: sd, + } + err = api.authorize(ctx) + if err != nil { + return nil, err + } + return api, nil +} + +var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled") + +func (a *API) authorize(ctx context.Context) (err error) { + _, err = a.oauth.AccessToken(ctx) + // already authorized + if err == nil { + return nil + } + + if a.Data.DisableAuthorize { + return ErrAuthorizeDisabled + } + + defer func() { + if err == nil { + a.cb.AuthDone(a.Data.ID, a.Data.Type) + } + }() + + scope := "config" + url, err := a.oauth.AuthURL(scope) + if err != nil { + return err + } + if a.Data.ProcessAuth != nil { + url = a.Data.ProcessAuth(url) + } + // We expect an uri if custom redirect is non empty + uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") + if err != nil { + return err + } + // The uri is only given here if a custom redirect is done + err = a.oauth.Exchange(ctx, uri) + if err != nil { + return err + } + return nil +} + +func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { + u := a.apiURL + endpoint + + // TODO: Cache HTTP client? + httpC := httpw.NewClient(a.oauth.NewHTTPClient()) + return httpC.Do(ctx, method, u, opts) +} + +func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { + h, body, err := a.authorized(ctx, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpw.StatusError{} + // Only retry authorized if we get an HTTP 401 + // TODO: Can the OAuth client handle this instead? + if errors.As(err, &statErr) && statErr.Status == 401 { + log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + a.oauth.SetTokenExpired() + h, body, err = a.authorized(ctx, method, endpoint, opts) + } + // Tokens is invalid we need to renew and authorize again + tErr := &eduoauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + // Mark the token as invalid and retry, so we trigger the authorization flow + a.oauth.SetTokenRenew() + log.Logger.Debugf("the tokens were invalid, trying again...") + if autherr := a.authorize(ctx); autherr != nil { + return nil, nil, autherr + } + return a.authorized(ctx, method, endpoint, opts) + } + return h, body, err +} + +func (a *API) Disconnect(ctx context.Context) error { + _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) + return err +} + +func (a *API) Info(ctx context.Context) (*profiles.Info, error) { + _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) + if err != nil { + return nil, fmt.Errorf("failed API /info: %v", err) + } + p := profiles.Info{} + if err = json.Unmarshal(body, &p); err != nil { + return nil, fmt.Errorf("failed API /info: %v", err) + } + return &p, nil +} + +type Proxy struct { + Listen string + Peer string +} + +type ConnectData struct { + Configuration string + Protocol protocol.Protocol + Expires time.Time + Proxy *wireguard.Proxy +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func protocolFromCT(ct string) (protocol.Protocol, error) { + switch ct { + case "application/x-wireguard-profile": + return protocol.WireGuard, nil + case "application/x-wireguard+tcp-profile": + return protocol.WireGuardTCP, nil + case "application/x-openvpn-profile": + return protocol.OpenVPN, nil + } + return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct) +} + +// Connect sends a /connect to an eduVPN server +// `ctx` is the context used for cancellation +// protos is the list of protocols supported and wanted by the client +func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + uv := url.Values{ + "profile_id": {prof.ID}, + } + + if len(protos) == 0 { + return nil, errors.New("no protocols supplied") + } + + var wgKey *wgtypes.Key + + // Loop over the protocols and set the correct headers and values + for _, p := range protos { + switch p { + case protocol.WireGuard: + gk, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + wgKey = &gk + // Set the public key + pubkey := wgKey.PublicKey() + uv.Set("public_key", pubkey.String()) + if !pTCP { + hdrs.Add("accept", "application/x-wireguard-profile") + } + hdrs.Add("accept", "application/x-wireguard+tcp-profile") + case protocol.OpenVPN: + hdrs.Add("accept", "application/x-openvpn-profile") + default: + return nil, errors.New("unknown protocol supplied") + } + } + // set prefer TCP + uv.Set("prefer_tcp", boolToYesNo(pTCP)) + + // Construct the parameters + params := &httpw.OptionalParams{Headers: hdrs, Body: uv} + h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) + if err != nil { + return nil, fmt.Errorf("failed API /connect call: %v", err) + } + + // Parse expiry + expH := h.Get("expires") + expT, err := http.ParseTime(expH) + if err != nil { + return nil, fmt.Errorf("failed parsing expiry time: %v", err) + } + + vpnCfg := string(body) + // Parse content type + contentH := h.Get("content-type") + proto, err := protocolFromCT(contentH) + if err != nil { + return nil, err + } + + if proto == protocol.OpenVPN { + // ensure scripts are not ran by default by append script-security 0 to the config + vpnCfg += "\nscript-security 0" + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil + } + + vpnCfg, proxy, err := wireguard.Config(vpnCfg, wgKey, proto == protocol.WireGuardTCP) + if err != nil { + return nil, err + } + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + Proxy: proxy, + }, nil +} + +func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) { + uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal") + if err != nil { + return nil, err + } + httpC := httpw.NewClient(nil) + _, body, err := httpC.Get(ctx, uStr) + if err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + err = ep.Validate() + if err != nil { + return nil, err + } + return &ep, nil +} + +func refreshEndpoints(ctx context.Context, sd ServerData) (*endpoints.List, *endpoints.List, error) { + // Get the endpoints + ep, err := getEndpoints(ctx, sd.BaseWK) + if err != nil { + return nil, nil, err + } + + // This is a mess but we essentially have to instantiate different endpoints if the authorization base URL is different from the base portal URL + // This happens with secure internet when the location is not equal to the home location + var epauth *endpoints.Endpoints + if sd.BaseAuthWK != sd.BaseWK { + oep, err := getEndpoints(ctx, sd.BaseAuthWK) + if err != nil { + return nil, nil, err + } + epauth = oep + } else { + epauth = ep + } + return &ep.API.V3, &epauth.API.V3, err +} + +type OAuthLogger struct{} + +func (ol *OAuthLogger) Logf(msg string, params ...interface{}) { + log.Logger.Debugf(msg, params...) +} + +func (ol *OAuthLogger) Log(msg string) { + log.Logger.Debugf("%s", msg) +} + +func init() { + eduoauth.UpdateLogger(&OAuthLogger{}) +} diff --git a/internal/api/endpoints/endpoints.go b/internal/api/endpoints/endpoints.go new file mode 100644 index 0000000..11e244b --- /dev/null +++ b/internal/api/endpoints/endpoints.go @@ -0,0 +1,51 @@ +package endpoints + +import ( + "fmt" + "net/url" +) + +type List struct { + API string `json:"api_endpoint"` + Authorization string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` +} + +type Versions struct { + V2 List `json:"http://eduvpn.org/api#2"` + V3 List `json:"http://eduvpn.org/api#3"` +} + +// Endpoints defines the json format for /.well-known/vpn-user-portal". +type Endpoints struct { + API Versions `json:"api"` + V string `json:"v"` +} + +// Validate validates the endpoints by parsing them and checking the scheme is HTTP +// An error is returned if they are not valid +func (e Endpoints) Validate() error { + v3 := e.API.V3 + pAPI, err := url.Parse(v3.API) + if err != nil { + return fmt.Errorf("failed to parse API endpoint: %w", err) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return fmt.Errorf("failed to parse API authorization endpoint: %w", err) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return fmt.Errorf("failed to parse API token endpoint: %w", err) + } + if pAPI.Scheme != "https" { + return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme) + } + if pAPI.Scheme != pAuth.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + return nil +} diff --git a/internal/api/profiles/profiles.go b/internal/api/profiles/profiles.go new file mode 100644 index 0000000..2f4fed7 --- /dev/null +++ b/internal/api/profiles/profiles.go @@ -0,0 +1,82 @@ +package profiles + +import ( + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Profile struct { + ID string `json:"profile_id"` + DisplayName string `json:"display_name"` + VPNProtoList []string `json:"vpn_proto_list"` + DefaultGateway bool `json:"default_gateway"` + DNSSearchDomains []string `json:"dns_search_domain_list"` +} + +type ListInfo struct { + ProfileList []Profile `json:"profile_list"` +} + +type Info struct { + Info ListInfo `json:"info"` +} + +func (i Info) Len() int { + return len(i.Info.ProfileList) +} + +func (i Info) Get(id string) *Profile { + for _, p := range i.Info.ProfileList { + if p.ID == id { + return &p + } + } + return nil +} + +func (i Info) MustIndex(n int) Profile { + return i.Info.ProfileList[n] +} + +func hasProtocol(protos []string, proto protocol.Protocol) bool { + for _, p := range protos { + if protocol.New(p) == proto { + return true + } + } + return false +} + +func (p *Profile) HasOpenVPN() bool { + return hasProtocol(p.VPNProtoList, protocol.OpenVPN) +} + +func (p *Profile) HasWireGuard() bool { + return hasProtocol(p.VPNProtoList, protocol.WireGuard) +} + +func (i Info) FilterWireGuard() *Info { + var ret []Profile + for _, p := range i.Info.ProfileList { + if !p.HasOpenVPN() { + continue + } + } + return &Info{ + Info: ListInfo{ + ProfileList: ret, + }, + } +} + +func (i Info) Public() server.Profiles { + m := make(map[string]server.Profile) + for _, p := range i.Info.ProfileList { + m[p.ID] = server.Profile{ + DisplayName: map[string]string{ + "en": p.DisplayName, + }, + } + } + return server.Profiles{Map: m} +} diff --git a/internal/api/redirect.go b/internal/api/redirect.go new file mode 100644 index 0000000..5d9e749 --- /dev/null +++ b/internal/api/redirect.go @@ -0,0 +1,23 @@ +package api + +// customRedirects supplies redirect values that should be handled by the app itself +// here we hardcode the redirect values that we should use in the OAuth requests +// these values were taken from https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php +var customRedirects = map[string]string{ + "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback", + "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app:/api/callback", + "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback", + "org.eduvpn.app.android": "org.eduvpn.app:/api/callback", +} + +// customRedirect returns the custom redirect string for the clientID `cid` +// Empty string if none is defined or one is defined but is empty. +// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects +// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests +func customRedirect(cid string) string { + v, ok := customRedirects[cid] + if !ok { + return "" + } + return v +} |
