diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-02-06 16:27:45 +0100 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-02-19 14:15:07 +0100 |
| commit | a84050a5e93f5fb9f5bbb79ca21b37e8359cf289 (patch) | |
| tree | ecdf0cea81b0bd6a3cf669f2b31c45a222d1c5f5 | |
| parent | 3152078aec8334357a61171838f664eb03299211 (diff) | |
Server: Refactor internal server package to use new state file
This completely rewrites the internal server package. Some advantages:
- Caches less
- Uses a callback interface so that the client package does not get so
convoluted
- Introduce a new API package that only deals with the server API and
uses github.com/jwijenbergh/eduoauth-go
| -rw-r--r-- | internal/api/api.go | 362 | ||||
| -rw-r--r-- | internal/api/endpoints/endpoints.go | 51 | ||||
| -rw-r--r-- | internal/api/profiles/profiles.go | 82 | ||||
| -rw-r--r-- | internal/api/redirect.go | 23 | ||||
| -rw-r--r-- | internal/server/custom.go | 63 | ||||
| -rw-r--r-- | internal/server/institute.go | 73 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 91 | ||||
| -rw-r--r-- | internal/server/server.go | 355 | ||||
| -rw-r--r-- | internal/server/servers.go | 113 | ||||
| -rw-r--r-- | internal/server/time.go | 70 |
10 files changed, 1079 insertions, 204 deletions
diff --git a/internal/api/api.go b/internal/api/api.go new file mode 100644 index 0000000..c9f0408 --- /dev/null +++ b/internal/api/api.go @@ -0,0 +1,362 @@ +package api + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/jwijenbergh/eduoauth-go" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/eduvpn/eduvpn-common/internal/api/endpoints" + "github.com/eduvpn/eduvpn-common/internal/api/profiles" + httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Callbacks interface { + TriggerAuth(context.Context, string, bool) (string, error) + AuthDone(string, server.Type) + TokensUpdated(string, server.Type, eduoauth.Token) +} + +type ServerData struct { + // ID is the identifier for the server + ID string + // Type is the type of server + Type server.Type + // BaseWK is the base well-known endpoint + BaseWK string + // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet + BaseAuthWK string + // ProcessAuth processes the OAuth authorization + ProcessAuth func(string) string + // SetAuthorizeTime sets the authorization time + SetAuthorizeTime func(time.Time) + // DisableAuthorize indicates whether or not new authorization requests should be disabled + DisableAuthorize bool +} + +type API struct { + cb Callbacks + // oauth is the oauth object + oauth *eduoauth.OAuth + // apiURL is the API url to send a request to + apiURL string + // Data is the server data + Data ServerData +} + +// NewAPI creates a new API object by creating an OAuth object +func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) { + ep, epauth, err := refreshEndpoints(ctx, sd) + if err != nil { + return nil, err + } + + cr := customRedirect(clientID) + // Construct OAuth + o := eduoauth.OAuth{ + ClientID: clientID, + BaseAuthorizationURL: epauth.Authorization, + TokenURL: epauth.Token, + CustomRedirect: cr, + RedirectPath: "/callback", + TokensUpdated: func(tok eduoauth.Token) { + cb.TokensUpdated(sd.ID, sd.Type, tok) + }, + } + + if tokens != nil { + o.UpdateTokens(*tokens) + } + + api := &API{ + cb: cb, + oauth: &o, + apiURL: ep.API, + Data: sd, + } + err = api.authorize(ctx) + if err != nil { + return nil, err + } + return api, nil +} + +var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled") + +func (a *API) authorize(ctx context.Context) (err error) { + _, err = a.oauth.AccessToken(ctx) + // already authorized + if err == nil { + return nil + } + + if a.Data.DisableAuthorize { + return ErrAuthorizeDisabled + } + + defer func() { + if err == nil { + a.cb.AuthDone(a.Data.ID, a.Data.Type) + } + }() + + scope := "config" + url, err := a.oauth.AuthURL(scope) + if err != nil { + return err + } + if a.Data.ProcessAuth != nil { + url = a.Data.ProcessAuth(url) + } + // We expect an uri if custom redirect is non empty + uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") + if err != nil { + return err + } + // The uri is only given here if a custom redirect is done + err = a.oauth.Exchange(ctx, uri) + if err != nil { + return err + } + return nil +} + +func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { + u := a.apiURL + endpoint + + // TODO: Cache HTTP client? + httpC := httpw.NewClient(a.oauth.NewHTTPClient()) + return httpC.Do(ctx, method, u, opts) +} + +func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { + h, body, err := a.authorized(ctx, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpw.StatusError{} + // Only retry authorized if we get an HTTP 401 + // TODO: Can the OAuth client handle this instead? + if errors.As(err, &statErr) && statErr.Status == 401 { + log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + a.oauth.SetTokenExpired() + h, body, err = a.authorized(ctx, method, endpoint, opts) + } + // Tokens is invalid we need to renew and authorize again + tErr := &eduoauth.TokensInvalidError{} + if err != nil && errors.As(err, &tErr) { + // Mark the token as invalid and retry, so we trigger the authorization flow + a.oauth.SetTokenRenew() + log.Logger.Debugf("the tokens were invalid, trying again...") + if autherr := a.authorize(ctx); autherr != nil { + return nil, nil, autherr + } + return a.authorized(ctx, method, endpoint, opts) + } + return h, body, err +} + +func (a *API) Disconnect(ctx context.Context) error { + _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) + return err +} + +func (a *API) Info(ctx context.Context) (*profiles.Info, error) { + _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil) + if err != nil { + return nil, fmt.Errorf("failed API /info: %v", err) + } + p := profiles.Info{} + if err = json.Unmarshal(body, &p); err != nil { + return nil, fmt.Errorf("failed API /info: %v", err) + } + return &p, nil +} + +type Proxy struct { + Listen string + Peer string +} + +type ConnectData struct { + Configuration string + Protocol protocol.Protocol + Expires time.Time + Proxy *wireguard.Proxy +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func protocolFromCT(ct string) (protocol.Protocol, error) { + switch ct { + case "application/x-wireguard-profile": + return protocol.WireGuard, nil + case "application/x-wireguard+tcp-profile": + return protocol.WireGuardTCP, nil + case "application/x-openvpn-profile": + return protocol.OpenVPN, nil + } + return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct) +} + +// Connect sends a /connect to an eduVPN server +// `ctx` is the context used for cancellation +// protos is the list of protocols supported and wanted by the client +func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + } + uv := url.Values{ + "profile_id": {prof.ID}, + } + + if len(protos) == 0 { + return nil, errors.New("no protocols supplied") + } + + var wgKey *wgtypes.Key + + // Loop over the protocols and set the correct headers and values + for _, p := range protos { + switch p { + case protocol.WireGuard: + gk, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + wgKey = &gk + // Set the public key + pubkey := wgKey.PublicKey() + uv.Set("public_key", pubkey.String()) + if !pTCP { + hdrs.Add("accept", "application/x-wireguard-profile") + } + hdrs.Add("accept", "application/x-wireguard+tcp-profile") + case protocol.OpenVPN: + hdrs.Add("accept", "application/x-openvpn-profile") + default: + return nil, errors.New("unknown protocol supplied") + } + } + // set prefer TCP + uv.Set("prefer_tcp", boolToYesNo(pTCP)) + + // Construct the parameters + params := &httpw.OptionalParams{Headers: hdrs, Body: uv} + h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params) + if err != nil { + return nil, fmt.Errorf("failed API /connect call: %v", err) + } + + // Parse expiry + expH := h.Get("expires") + expT, err := http.ParseTime(expH) + if err != nil { + return nil, fmt.Errorf("failed parsing expiry time: %v", err) + } + + vpnCfg := string(body) + // Parse content type + contentH := h.Get("content-type") + proto, err := protocolFromCT(contentH) + if err != nil { + return nil, err + } + + if proto == protocol.OpenVPN { + // ensure scripts are not ran by default by append script-security 0 to the config + vpnCfg += "\nscript-security 0" + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + }, nil + } + + vpnCfg, proxy, err := wireguard.Config(vpnCfg, wgKey, proto == protocol.WireGuardTCP) + if err != nil { + return nil, err + } + return &ConnectData{ + Configuration: vpnCfg, + Protocol: proto, + Expires: expT, + Proxy: proxy, + }, nil +} + +func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) { + uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal") + if err != nil { + return nil, err + } + httpC := httpw.NewClient(nil) + _, body, err := httpC.Get(ctx, uStr) + if err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return nil, fmt.Errorf("failed getting server endpoints with error: %w", err) + } + err = ep.Validate() + if err != nil { + return nil, err + } + return &ep, nil +} + +func refreshEndpoints(ctx context.Context, sd ServerData) (*endpoints.List, *endpoints.List, error) { + // Get the endpoints + ep, err := getEndpoints(ctx, sd.BaseWK) + if err != nil { + return nil, nil, err + } + + // This is a mess but we essentially have to instantiate different endpoints if the authorization base URL is different from the base portal URL + // This happens with secure internet when the location is not equal to the home location + var epauth *endpoints.Endpoints + if sd.BaseAuthWK != sd.BaseWK { + oep, err := getEndpoints(ctx, sd.BaseAuthWK) + if err != nil { + return nil, nil, err + } + epauth = oep + } else { + epauth = ep + } + return &ep.API.V3, &epauth.API.V3, err +} + +type OAuthLogger struct{} + +func (ol *OAuthLogger) Logf(msg string, params ...interface{}) { + log.Logger.Debugf(msg, params...) +} + +func (ol *OAuthLogger) Log(msg string) { + log.Logger.Debugf("%s", msg) +} + +func init() { + eduoauth.UpdateLogger(&OAuthLogger{}) +} diff --git a/internal/api/endpoints/endpoints.go b/internal/api/endpoints/endpoints.go new file mode 100644 index 0000000..11e244b --- /dev/null +++ b/internal/api/endpoints/endpoints.go @@ -0,0 +1,51 @@ +package endpoints + +import ( + "fmt" + "net/url" +) + +type List struct { + API string `json:"api_endpoint"` + Authorization string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` +} + +type Versions struct { + V2 List `json:"http://eduvpn.org/api#2"` + V3 List `json:"http://eduvpn.org/api#3"` +} + +// Endpoints defines the json format for /.well-known/vpn-user-portal". +type Endpoints struct { + API Versions `json:"api"` + V string `json:"v"` +} + +// Validate validates the endpoints by parsing them and checking the scheme is HTTP +// An error is returned if they are not valid +func (e Endpoints) Validate() error { + v3 := e.API.V3 + pAPI, err := url.Parse(v3.API) + if err != nil { + return fmt.Errorf("failed to parse API endpoint: %w", err) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return fmt.Errorf("failed to parse API authorization endpoint: %w", err) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return fmt.Errorf("failed to parse API token endpoint: %w", err) + } + if pAPI.Scheme != "https" { + return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme) + } + if pAPI.Scheme != pAuth.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + return nil +} diff --git a/internal/api/profiles/profiles.go b/internal/api/profiles/profiles.go new file mode 100644 index 0000000..2f4fed7 --- /dev/null +++ b/internal/api/profiles/profiles.go @@ -0,0 +1,82 @@ +package profiles + +import ( + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Profile struct { + ID string `json:"profile_id"` + DisplayName string `json:"display_name"` + VPNProtoList []string `json:"vpn_proto_list"` + DefaultGateway bool `json:"default_gateway"` + DNSSearchDomains []string `json:"dns_search_domain_list"` +} + +type ListInfo struct { + ProfileList []Profile `json:"profile_list"` +} + +type Info struct { + Info ListInfo `json:"info"` +} + +func (i Info) Len() int { + return len(i.Info.ProfileList) +} + +func (i Info) Get(id string) *Profile { + for _, p := range i.Info.ProfileList { + if p.ID == id { + return &p + } + } + return nil +} + +func (i Info) MustIndex(n int) Profile { + return i.Info.ProfileList[n] +} + +func hasProtocol(protos []string, proto protocol.Protocol) bool { + for _, p := range protos { + if protocol.New(p) == proto { + return true + } + } + return false +} + +func (p *Profile) HasOpenVPN() bool { + return hasProtocol(p.VPNProtoList, protocol.OpenVPN) +} + +func (p *Profile) HasWireGuard() bool { + return hasProtocol(p.VPNProtoList, protocol.WireGuard) +} + +func (i Info) FilterWireGuard() *Info { + var ret []Profile + for _, p := range i.Info.ProfileList { + if !p.HasOpenVPN() { + continue + } + } + return &Info{ + Info: ListInfo{ + ProfileList: ret, + }, + } +} + +func (i Info) Public() server.Profiles { + m := make(map[string]server.Profile) + for _, p := range i.Info.ProfileList { + m[p.ID] = server.Profile{ + DisplayName: map[string]string{ + "en": p.DisplayName, + }, + } + } + return server.Profiles{Map: m} +} diff --git a/internal/api/redirect.go b/internal/api/redirect.go new file mode 100644 index 0000000..5d9e749 --- /dev/null +++ b/internal/api/redirect.go @@ -0,0 +1,23 @@ +package api + +// customRedirects supplies redirect values that should be handled by the app itself +// here we hardcode the redirect values that we should use in the OAuth requests +// these values were taken from https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php +var customRedirects = map[string]string{ + "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback", + "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app:/api/callback", + "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback", + "org.eduvpn.app.android": "org.eduvpn.app:/api/callback", +} + +// customRedirect returns the custom redirect string for the clientID `cid` +// Empty string if none is defined or one is defined but is empty. +// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects +// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests +func customRedirect(cid string) string { + v, ok := customRedirects[cid] + if !ok { + return "" + } + return v +} diff --git a/internal/server/custom.go b/internal/server/custom.go new file mode 100644 index 0000000..b4b81cb --- /dev/null +++ b/internal/server/custom.go @@ -0,0 +1,63 @@ +package server + +import ( + "context" + "time" + + "github.com/eduvpn/eduvpn-common/internal/api" + "github.com/eduvpn/eduvpn-common/internal/config/v2" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" +) + +func (s *Servers) AddCustom(ctx context.Context, id string, na bool) (*Server, error) { + sd := api.ServerData{ + ID: id, + Type: server.TypeCustom, + BaseWK: id, + BaseAuthWK: id, + } + + var a *api.API + var err error + if !na { + // Authorize by creating the API object + a, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil) + if err != nil { + return nil, err + } + } + + err = s.config.AddServer(id, server.TypeCustom, v2.Server{LastAuthorizeTime: time.Now()}) + if err != nil { + return nil, err + } + + cust := s.NewServer(id, server.TypeCustom, a) + // Return the server with the API + + return &cust, nil +} + +func (s *Servers) GetCustom(ctx context.Context, id string, tok *eduoauth.Token, disableAuth bool) (*Server, error) { + sd := api.ServerData{ + ID: id, + Type: server.TypeCustom, + BaseWK: id, + BaseAuthWK: id, + DisableAuthorize: disableAuth, + } + + // Get the server from the config + _, err := s.config.GetServer(id, server.TypeCustom) + if err != nil { + return nil, err + } + a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok) + if err != nil { + return nil, err + } + + cust := s.NewServer(id, server.TypeCustom, a) + return &cust, nil +} diff --git a/internal/server/institute.go b/internal/server/institute.go new file mode 100644 index 0000000..881f96d --- /dev/null +++ b/internal/server/institute.go @@ -0,0 +1,73 @@ +package server + +import ( + "context" + "time" + + "github.com/eduvpn/eduvpn-common/internal/api" + "github.com/eduvpn/eduvpn-common/internal/config/v2" + "github.com/eduvpn/eduvpn-common/internal/discovery" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" +) + +func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, na bool) (*Server, error) { + // This is basically done to double check if the server is part of the institute access section of disco + dsrv, err := disco.ServerByURL(id, "institute_access") + if err != nil { + return nil, err + } + + sd := api.ServerData{ + ID: dsrv.BaseURL, + Type: server.TypeInstituteAccess, + BaseWK: dsrv.BaseURL, + BaseAuthWK: dsrv.BaseURL, + } + + var a *api.API + if !na { + // Authorize by creating the API object + a, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil) + if err != nil { + return nil, err + } + } + + err = s.config.AddServer(dsrv.BaseURL, server.TypeInstituteAccess, v2.Server{LastAuthorizeTime: time.Now()}) + if err != nil { + return nil, err + } + + inst := s.NewServer(dsrv.BaseURL, server.TypeInstituteAccess, a) + return &inst, nil +} + +func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { + // This is basically done to double check if the server is part of the institute access section of disco + dsrv, err := disco.ServerByURL(id, "institute_access") + if err != nil { + return nil, err + } + + // Get the server from the config + _, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess) + if err != nil { + return nil, err + } + sd := api.ServerData{ + ID: dsrv.BaseURL, + Type: server.TypeInstituteAccess, + BaseWK: dsrv.BaseURL, + BaseAuthWK: dsrv.BaseURL, + DisableAuthorize: disableAuth, + } + // Authorize by creating the API object + a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok) + if err != nil { + return nil, err + } + + inst := s.NewServer(dsrv.BaseURL, server.TypeInstituteAccess, a) + return &inst, nil +} diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go new file mode 100644 index 0000000..19e75a1 --- /dev/null +++ b/internal/server/secureinternet.go @@ -0,0 +1,91 @@ +package server + +import ( + "context" + "errors" + "time" + + "github.com/eduvpn/eduvpn-common/internal/api" + "github.com/eduvpn/eduvpn-common/internal/config/v2" + "github.com/eduvpn/eduvpn-common/internal/discovery" + "github.com/eduvpn/eduvpn-common/internal/util" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" +) + +func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, na bool) (*Server, error) { + if s.config.HasSecureInternet() { + return nil, errors.New("a secure internet server already exists") + } + dorg, dsrv, err := disco.SecureHomeArgs(orgID) + if err != nil { + // We mark the organizations as expired because we got an error + // Note that in the docs it states that it only should happen when the Org ID doesn't exist + // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found + disco.MarkOrganizationsExpired() + return nil, err + } + + sd := api.ServerData{ + ID: orgID, + Type: server.TypeSecureInternet, + BaseWK: dsrv.BaseURL, + BaseAuthWK: dsrv.BaseURL, + ProcessAuth: func(url string) string { + return util.ReplaceWAYF(dsrv.AuthenticationURLTemplate, url, dorg.OrgID) + }, + } + + var a *api.API + if !na { + // Authorize by creating the API object + a, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil) + if err != nil { + return nil, err + } + } + + err = s.config.AddServer(orgID, server.TypeSecureInternet, v2.Server{CountryCode: dsrv.CountryCode, LastAuthorizeTime: time.Now()}) + if err != nil { + return nil, err + } + + sec := s.NewServer(orgID, server.TypeSecureInternet, a) + return &sec, nil +} + +func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { + srv, err := s.config.GetServer(orgID, server.TypeSecureInternet) + if err != nil { + return nil, err + } + + dorg, dhome, err := disco.SecureHomeArgs(orgID) + if err != nil { + return nil, err + } + + dloc, err := disco.ServerByCountryCode(srv.CountryCode) + if err != nil { + return nil, err + } + + sd := api.ServerData{ + ID: dorg.OrgID, + Type: server.TypeSecureInternet, + BaseWK: dloc.BaseURL, + BaseAuthWK: dhome.BaseURL, + ProcessAuth: func(url string) string { + return util.ReplaceWAYF(dhome.AuthenticationURLTemplate, url, dorg.OrgID) + }, + DisableAuthorize: disableAuth, + } + + a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok) + if err != nil { + return nil, err + } + + sec := s.NewServer(orgID, server.TypeSecureInternet, a) + return &sec, nil +} diff --git a/internal/server/server.go b/internal/server/server.go index 1bdef28..97dafff 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,280 +2,227 @@ package server import ( "context" + "errors" "os" "time" - "github.com/eduvpn/eduvpn-common/internal/discovery" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/server/api" - "github.com/eduvpn/eduvpn-common/internal/server/base" - "github.com/eduvpn/eduvpn-common/internal/server/profile" - "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/internal/api" + "github.com/eduvpn/eduvpn-common/internal/api/profiles" + v2 "github.com/eduvpn/eduvpn-common/internal/config/v2" "github.com/eduvpn/eduvpn-common/types/protocol" srvtypes "github.com/eduvpn/eduvpn-common/types/server" - "github.com/go-errors/errors" ) -type Server interface { - // OAuth returns the struct used for OAuth - OAuth() *oauth.OAuth - - // TemplateAuth returns the authorization URL template function - TemplateAuth() func(string) string - - // Base returns the server base - Base() (*base.Base, error) - - // NeedsLocation checks if the server needs a secure internet location - NeedsLocation() bool - - // Public returns the representation that will be passed over the CGO barrier - Public() (interface{}, error) - - // RefreshEndpoints refreshes the endpoints for the server - RefreshEndpoints(context.Context, *discovery.Discovery) error +type Server struct { + identifier string + t srvtypes.Type + apiw *api.API + storage *v2.V2 } -// Name gets the name for the server and falls back to a default of "Unknown Server" -func Name(srv Server) string { - n := "Unknown Server" - if b, err := srv.Base(); err == nil { - n = b.URL - } - return n -} - -func UpdateTokens(srv Server, t oauth.Token) { - srv.OAuth().UpdateTokens(t) -} - -func OAuthURL(srv Server, name string, cr string) (string, error) { - return srv.OAuth().AuthURL(name, srv.TemplateAuth(), cr) -} +var ErrInvalidProfile = errors.New("invalid profile") -func OAuthExchange(ctx context.Context, srv Server, uri string) error { - return srv.OAuth().Exchange(ctx, uri) -} - -func HeaderToken(ctx context.Context, srv Server) (string, error) { - return srv.OAuth().AccessToken(ctx) -} - -func MarkTokenExpired(srv Server) { - srv.OAuth().SetTokenExpired() -} - -func MarkTokensForRenew(srv Server) { - srv.OAuth().SetTokenRenew() -} - -func NeedsRelogin(ctx context.Context, srv Server) bool { - // TODO: this error can be a context cancel - _, err := HeaderToken(ctx, srv) - return err != nil +func (s *Servers) NewServer(identifier string, t srvtypes.Type, api *api.API) Server { + return Server{ + identifier: identifier, + t: t, + apiw: api, + storage: s.config, + } } -func CurrentProfile(srv Server) (*profile.Profile, error) { - b, err := srv.Base() +// Profiles gets the profiles for the server +// It only does a /info network request if the profiles have not been cached +// force indicates whether or not the profiles should be fetched fresh +func (s *Server) Profiles(ctx context.Context) (*profiles.Info, error) { + a, err := s.api() if err != nil { return nil, err } - pID := b.Profiles.Current - for _, profile := range b.Profiles.Info.ProfileList { - if profile.ID == pID { - return &profile, nil - } + // Otherwise get fresh profiles and set the cache + prfs, err := a.Info(ctx) + if err != nil { + return nil, err } - - return nil, errors.Errorf("profile not found: " + pID) -} - -func ValidProfiles(srv Server, wireguardSupport bool) (*profile.Info, error) { - // No error wrapping here otherwise we wrap it too much - b, err := srv.Base() + err = s.SetProfileList(prfs.Public()) if err != nil { return nil, err } - ps := b.Profiles.Supported(wireguardSupport) - if len(ps) == 0 { - return nil, errors.Errorf("no profiles found with supported protocols") + return prfs, nil +} + +func (s *Server) api() (*api.API, error) { + if s.apiw == nil { + return nil, errors.New("no API object found") } - return &profile.Info{ - Current: b.Profiles.Current, - Info: profile.ListInfo{ - ProfileList: ps, - }, - }, nil + return s.apiw, nil } -func Profile(srv Server, id string) error { - b, err := srv.Base() +func (s *Server) findProfile(ctx context.Context, wgSupport bool) (*profiles.Profile, error) { + // Get the profiles by ignoring the cache + prfs, err := s.Profiles(ctx) if err != nil { - return err + return nil, err } - if !b.Profiles.Has(id) { - return errors.Errorf("no profile available with id: %s", id) + + // No profiles available + if prfs.Len() == 0 { + return nil, errors.New("the server has no available profiles for your account") } - b.Profiles.Current = id - return nil -} -type ConfigData struct { - // The configuration - Config string + // No WireGuard support, we have to filter the profiles that only have WireGuard + if !wgSupport { + prfs = prfs.FilterWireGuard() + } - // The type of configuration - Type string -} + var chosenP profiles.Profile -// Public gets the public data from the types package -// dg specifies if this config is default gateway -func (c *ConfigData) Public(dg bool) srvtypes.Configuration { - return srvtypes.Configuration{ - VPNConfig: c.Config, - Protocol: protocol.New(c.Type), - DefaultGateway: dg, + n := prfs.Len() + switch n { + // If we now get no profiles then that means a profile with only WireGuard was removed + case 0: + return nil, errors.New("the server has only WireGuard profiles but the client does not support WireGuard") + case 1: + // Only one profile, make sure it is set + chosenP = prfs.MustIndex(0) + default: + // Profile doesn't exist + prID, err := s.ProfileID() + if err != nil { + return nil, err + } + v := prfs.Get(prID) + if v == nil { + return nil, ErrInvalidProfile + } + chosenP = *v } + return &chosenP, nil } -func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { - b, err := srv.Base() +func (s *Server) connect(ctx context.Context, wgSupport bool, pTCP bool) (*srvtypes.Configuration, error) { + a, err := s.api() if err != nil { return nil, err } - pID := b.Profiles.Current - key, err := wireguard.GenerateKey() + // find a suitable profile to connect + chosenP, err := s.findProfile(ctx, wgSupport) if err != nil { return nil, err } - - pub := key.PublicKey().String() - cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport) + err = s.SetProfileID(chosenP.ID) if err != nil { return nil, err } - // Store start and end time - b.StartTime = time.Now() - b.EndTime = exp - - if proto == "wireguard" { - // This needs the go code a way to identify a connection - // Use the uuid of the connection e.g. on Linux - // This needs the client code to call the go code - - cfg = wireguard.ConfigAddKey(cfg, key) + protos := []protocol.Protocol{protocol.OpenVPN} + if wgSupport { + protos = append(protos, protocol.WireGuard) } - - return &ConfigData{Config: cfg, Type: proto}, nil -} - -func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) { - b, err := srv.Base() + // If the client supports WireGuard and the profile supports both protocols we remove openvpn from client support if EDUVPN_PREFER_WG is set to "1" + // This also only happens if prefer TCP is set to false + // TODO: remove the prefer TCP check when we have implemented proxyguard + if wgSupport && os.Getenv("EDUVPN_PREFER_WG") == "1" { + if chosenP.HasWireGuard() && chosenP.HasOpenVPN() { + protos = []protocol.Protocol{protocol.WireGuard} + } + } + // SAFETY: chosenP is guaranteed to be non-nil + apicfg, err := a.Connect(ctx, *chosenP, protos, pTCP) if err != nil { return nil, err } - pid := b.Profiles.Current - cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP) + err = s.SetExpireTime(apicfg.Expires) if err != nil { return nil, err } - - // Store start and end time - b.StartTime = time.Now() - b.EndTime = exp - - return &ConfigData{Config: cfg, Type: "openvpn"}, nil + var proxy *srvtypes.Proxy + if apicfg.Proxy != nil { + proxy = &srvtypes.Proxy{ + SourcePort: apicfg.Proxy.SourcePort, + Listen: apicfg.Proxy.Listen, + Peer: apicfg.Proxy.Peer, + } + } + return &srvtypes.Configuration{ + VPNConfig: apicfg.Configuration, + Protocol: apicfg.Protocol, + DefaultGateway: chosenP.DefaultGateway, + DNSSearchDomains: chosenP.DNSSearchDomains, + Proxy: proxy, + }, nil } -func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) { - b, err := srv.Base() +func (s *Server) Disconnect(ctx context.Context) error { + a, err := s.api() if err != nil { - return false, err - } - // Get new profiles using the info call - // This does not override the current profile - err = api.Info(ctx, b, srv.OAuth()) - if err != nil { - return false, err + return err } + return a.Disconnect(ctx) +} - // If there was a profile chosen and it doesn't exist anymore, reset it - if b.Profiles.Current != "" { - if _, err = CurrentProfile(srv); err != nil { - b.Profiles.Current = "" - } +func (s *Server) cfgServer() (*v2.Server, error) { + if s.storage == nil { + return nil, errors.New("cannot get server, no configuration passed") } + return s.storage.GetServer(s.identifier, s.t) +} - // there are multiple profiles and no selection has been made - if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" { - return false, nil +func (s *Server) SetProfileID(id string) error { + cs, err := s.cfgServer() + if err != nil { + return err } + cs.Profiles.Current = id + return nil +} - // Set the current profile if there is only one profile or profile is already selected - // Set the first profile if none is selected - if b.Profiles.Current == "" { - b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID - } - p, err := CurrentProfile(srv) - // shouldn't happen +func (s *Server) SetProfileList(prfs srvtypes.Profiles) error { + cs, err := s.cfgServer() if err != nil { - return false, err - } - // Profile does not support OpenVPN but the client also doesn't support WireGuard - if !p.SupportsOpenVPN() && !wireguardSupport { - return false, nil + return err } - return true, nil + cs.Profiles.Map = prfs.Map + return nil } -func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { - p, err := CurrentProfile(server) +func (s *Server) SetExpireTime(et time.Time) error { + cs, err := s.cfgServer() if err != nil { - return nil, err + return err } + cs.ExpireTime = et + return nil +} - ovpn := p.SupportsOpenVPN() - wg := p.SupportsWireguard() && wireguardSupport - - // If we don't prefer TCP and this profile and client supports wireguard, - // we disable openvpn if the EDUVPN_PREFER_WG environment variable is set - // This is useful to force WireGuard if the profile supports both OpenVPN and WireGuard but the server still prefers OpenVPN - if !preferTCP && wg { - if os.Getenv("EDUVPN_PREFER_WG") == "1" { - ovpn = false - } +func (s *Server) ProfileID() (string, error) { + cs, err := s.cfgServer() + if err != nil { + return "", err } + return cs.Profiles.Current, nil +} - var cfg *ConfigData - - switch { - // The config supports wireguard and optionally openvpn - case wg: - // A wireguard connect call needs to generate a wireguard key and add it to the config - // Also the server could send back an OpenVPN config if it supports OpenVPN - cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn) - // The config only supports OpenVPN - case ovpn: - cfg, err = openVPNGetConfig(ctx, server, preferTCP) - // The config supports no available protocol because the profile only supports WireGuard but the client doesn't - default: - return nil, errors.New("no supported protocol found") +func (s *Server) SetLocation(loc string) error { + if s.t != srvtypes.TypeSecureInternet { + return errors.New("changing secure internet location is only possible when the server is a secure location") } - - // Add script security 0 to disable OpenVPN scripts - // The client may override this but we provide the default protection here - if err == nil && cfg.Type == "openvpn" { - cfg.Config += "\nscript-security 0" + cs, err := s.cfgServer() + if err != nil { + return err } - return cfg, err + cs.CountryCode = loc + return nil } -func Disconnect(ctx context.Context, server Server) error { - b, err := server.Base() - if err != nil { - return err +func (s *Server) SetCurrent() error { + if s.storage == nil { + return errors.New("no storage available") } - return api.Disconnect(ctx, b, server.OAuth()) + s.storage.LastChosen = &v2.ServerType{ + ID: s.identifier, + T: s.t, + } + return nil } diff --git a/internal/server/servers.go b/internal/server/servers.go new file mode 100644 index 0000000..fe2550c --- /dev/null +++ b/internal/server/servers.go @@ -0,0 +1,113 @@ +package server + +import ( + "context" + "errors" + "fmt" + + "github.com/eduvpn/eduvpn-common/internal/api" + "github.com/eduvpn/eduvpn-common/internal/config/v2" + "github.com/eduvpn/eduvpn-common/internal/discovery" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" +) + +type Callbacks interface { + api.Callbacks + GettingConfig() error + InvalidProfile(context.Context, *Server) (string, error) +} + +type Servers struct { + clientID string + cb Callbacks + WGSupport bool + config *v2.V2 +} + +func (s *Servers) Remove(identifier string, t srvtypes.Type) error { + return s.config.RemoveServer(identifier, t) +} + +func NewServers(name string, cb Callbacks, wgSupport bool, cfg *v2.V2) Servers { + return Servers{ + clientID: name, + cb: cb, + WGSupport: wgSupport, + config: cfg, + } +} + +type CurrentServer struct { + *v2.Server + T v2.ServerType + srvs *Servers +} + +func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { + switch cs.T.T { + case srvtypes.TypeInstituteAccess: + return cs.srvs.GetInstitute(ctx, cs.T.ID, disco, tokens, disableAuth) + case srvtypes.TypeSecureInternet: + return cs.srvs.GetSecure(ctx, cs.T.ID, disco, tokens, disableAuth) + case srvtypes.TypeCustom: + return cs.srvs.GetCustom(ctx, cs.T.ID, tokens, disableAuth) + default: + return nil, fmt.Errorf("no such server type: %d", cs.T.T) + } +} + +func (s *Servers) GetServer(id string, t srvtypes.Type) (*v2.Server, error) { + if s.config == nil { + return nil, errors.New("no configuration available") + } + return s.config.GetServer(id, t) +} + +func (s *Servers) CurrentServer() (*CurrentServer, error) { + curr, k, err := s.config.CurrentServer() + if err != nil { + return nil, err + } + return &CurrentServer{ + Server: curr, + T: *k, + srvs: s, + }, nil +} + +func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) { + return s.config.PublicCurrent(disco) +} + +func (s *Servers) ConnectWithCallbacks(ctx context.Context, srv *Server, pTCP bool) (*srvtypes.Configuration, error) { + err := srv.SetCurrent() + if err != nil { + return nil, err + } + err = s.cb.GettingConfig() + if err != nil { + return nil, err + } + cfg, err := srv.connect(ctx, s.WGSupport, pTCP) + if err == nil { + return cfg, nil + } + if !errors.Is(err, ErrInvalidProfile) { + return cfg, err + } + // Get a new profile from the callback + pr, err := s.cb.InvalidProfile(ctx, srv) + if err != nil { + return cfg, err + } + err = srv.SetProfileID(pr) + if err != nil { + return nil, err + } + err = s.cb.GettingConfig() + if err != nil { + return nil, err + } + return srv.connect(ctx, s.WGSupport, pTCP) +} diff --git a/internal/server/time.go b/internal/server/time.go new file mode 100644 index 0000000..4e54ef6 --- /dev/null +++ b/internal/server/time.go @@ -0,0 +1,70 @@ +package server + +import "time" + +// RenewButtonTime returns the time when the renew button should be shown for the server +// Implemented according to: https://github.com/eduvpn/documentation/blob/cdf4d054f7652d74e4192494e8bb0e21040e46ac/API.md#session-expiry +func RenewButtonTime(st time.Time, et time.Time) int64 { + d := et.Sub(st) + + // If the time is less than 24 hours (a day left), show it when 30 minutes have passed or on expired if less than 30 minutes + dayl := time.Duration(24 * time.Hour) + if d < dayl { + // Get the minimum time to add, 30 minutes or on expired + m := time.Duration(30 * time.Minute) + // The total delta time is larger, return that we should show the button after 30 minutes + if d > m { + return st.Add(30 * time.Minute).Unix() + } + // Just show it on expired + return st.Add(d).Unix() + } + + // Else just show it when 24 hours is left + // This is the delta minus 24 hours left as that's how long it takes for a day to be left in the expiry + // We thus add this to the start time + tillDay := d - dayl + t := st.Add(tillDay) + return t.Unix() +} + +func CountdownTime(st time.Time, et time.Time) int64 { + d := et.Sub(st) + + dayl := time.Duration(24 * time.Hour) + + // This is just the last 24 hours + // if less than or equal to 24 hours, immediately + if d <= dayl { + return st.Unix() + } + + tillDay := d - dayl + t := st.Add(tillDay) + return t.Unix() +} + +func NotificationTimes(st time.Time, et time.Time) []int64 { + last := []time.Duration{ + time.Duration(0), + time.Duration(1 * time.Hour), + time.Duration(2 * time.Hour), + time.Duration(4 * time.Hour), + } + + var t []int64 + + d := et.Sub(st) + for _, l := range last { + // If the notification remaining time is more than the total delta, continue + if l > d { + continue + } + // calculating the time till a notification must happen + tillN := d - l + // Get absolute time when this notification must be shown by adding the delta + c := st.Add(tillN) + t = append(t, c.Unix()) + } + return t +} |
