summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/api/api.go362
-rw-r--r--internal/api/endpoints/endpoints.go51
-rw-r--r--internal/api/profiles/profiles.go82
-rw-r--r--internal/api/redirect.go23
-rw-r--r--internal/server/custom.go63
-rw-r--r--internal/server/institute.go73
-rw-r--r--internal/server/secureinternet.go91
-rw-r--r--internal/server/server.go355
-rw-r--r--internal/server/servers.go113
-rw-r--r--internal/server/time.go70
10 files changed, 1079 insertions, 204 deletions
diff --git a/internal/api/api.go b/internal/api/api.go
new file mode 100644
index 0000000..c9f0408
--- /dev/null
+++ b/internal/api/api.go
@@ -0,0 +1,362 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+ "time"
+
+ "github.com/jwijenbergh/eduoauth-go"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/eduvpn/eduvpn-common/internal/api/endpoints"
+ "github.com/eduvpn/eduvpn-common/internal/api/profiles"
+ httpw "github.com/eduvpn/eduvpn-common/internal/http"
+ "github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/internal/wireguard"
+ "github.com/eduvpn/eduvpn-common/types/protocol"
+ "github.com/eduvpn/eduvpn-common/types/server"
+)
+
+type Callbacks interface {
+ TriggerAuth(context.Context, string, bool) (string, error)
+ AuthDone(string, server.Type)
+ TokensUpdated(string, server.Type, eduoauth.Token)
+}
+
+type ServerData struct {
+ // ID is the identifier for the server
+ ID string
+ // Type is the type of server
+ Type server.Type
+ // BaseWK is the base well-known endpoint
+ BaseWK string
+ // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet
+ BaseAuthWK string
+ // ProcessAuth processes the OAuth authorization
+ ProcessAuth func(string) string
+ // SetAuthorizeTime sets the authorization time
+ SetAuthorizeTime func(time.Time)
+ // DisableAuthorize indicates whether or not new authorization requests should be disabled
+ DisableAuthorize bool
+}
+
+type API struct {
+ cb Callbacks
+ // oauth is the oauth object
+ oauth *eduoauth.OAuth
+ // apiURL is the API url to send a request to
+ apiURL string
+ // Data is the server data
+ Data ServerData
+}
+
+// NewAPI creates a new API object by creating an OAuth object
+func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, tokens *eduoauth.Token) (*API, error) {
+ ep, epauth, err := refreshEndpoints(ctx, sd)
+ if err != nil {
+ return nil, err
+ }
+
+ cr := customRedirect(clientID)
+ // Construct OAuth
+ o := eduoauth.OAuth{
+ ClientID: clientID,
+ BaseAuthorizationURL: epauth.Authorization,
+ TokenURL: epauth.Token,
+ CustomRedirect: cr,
+ RedirectPath: "/callback",
+ TokensUpdated: func(tok eduoauth.Token) {
+ cb.TokensUpdated(sd.ID, sd.Type, tok)
+ },
+ }
+
+ if tokens != nil {
+ o.UpdateTokens(*tokens)
+ }
+
+ api := &API{
+ cb: cb,
+ oauth: &o,
+ apiURL: ep.API,
+ Data: sd,
+ }
+ err = api.authorize(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return api, nil
+}
+
+var ErrAuthorizeDisabled = errors.New("cannot authorize as re-authorization is disabled")
+
+func (a *API) authorize(ctx context.Context) (err error) {
+ _, err = a.oauth.AccessToken(ctx)
+ // already authorized
+ if err == nil {
+ return nil
+ }
+
+ if a.Data.DisableAuthorize {
+ return ErrAuthorizeDisabled
+ }
+
+ defer func() {
+ if err == nil {
+ a.cb.AuthDone(a.Data.ID, a.Data.Type)
+ }
+ }()
+
+ scope := "config"
+ url, err := a.oauth.AuthURL(scope)
+ if err != nil {
+ return err
+ }
+ if a.Data.ProcessAuth != nil {
+ url = a.Data.ProcessAuth(url)
+ }
+ // We expect an uri if custom redirect is non empty
+ uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "")
+ if err != nil {
+ return err
+ }
+ // The uri is only given here if a custom redirect is done
+ err = a.oauth.Exchange(ctx, uri)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
+ u := a.apiURL + endpoint
+
+ // TODO: Cache HTTP client?
+ httpC := httpw.NewClient(a.oauth.NewHTTPClient())
+ return httpC.Do(ctx, method, u, opts)
+}
+
+func (a *API) authorizedRetry(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) {
+ h, body, err := a.authorized(ctx, method, endpoint, opts)
+ if err == nil {
+ return h, body, nil
+ }
+
+ statErr := &httpw.StatusError{}
+ // Only retry authorized if we get an HTTP 401
+ // TODO: Can the OAuth client handle this instead?
+ if errors.As(err, &statErr) && statErr.Status == 401 {
+ log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint)
+ // Mark the token as expired and retry, so we trigger the refresh flow
+ a.oauth.SetTokenExpired()
+ h, body, err = a.authorized(ctx, method, endpoint, opts)
+ }
+ // Tokens is invalid we need to renew and authorize again
+ tErr := &eduoauth.TokensInvalidError{}
+ if err != nil && errors.As(err, &tErr) {
+ // Mark the token as invalid and retry, so we trigger the authorization flow
+ a.oauth.SetTokenRenew()
+ log.Logger.Debugf("the tokens were invalid, trying again...")
+ if autherr := a.authorize(ctx); autherr != nil {
+ return nil, nil, autherr
+ }
+ return a.authorized(ctx, method, endpoint, opts)
+ }
+ return h, body, err
+}
+
+func (a *API) Disconnect(ctx context.Context) error {
+ _, _, err := a.authorized(ctx, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second})
+ return err
+}
+
+func (a *API) Info(ctx context.Context) (*profiles.Info, error) {
+ _, body, err := a.authorizedRetry(ctx, http.MethodGet, "/info", nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed API /info: %v", err)
+ }
+ p := profiles.Info{}
+ if err = json.Unmarshal(body, &p); err != nil {
+ return nil, fmt.Errorf("failed API /info: %v", err)
+ }
+ return &p, nil
+}
+
+type Proxy struct {
+ Listen string
+ Peer string
+}
+
+type ConnectData struct {
+ Configuration string
+ Protocol protocol.Protocol
+ Expires time.Time
+ Proxy *wireguard.Proxy
+}
+
+// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
+func boolToYesNo(preferTCP bool) string {
+ if preferTCP {
+ return "yes"
+ }
+ return "no"
+}
+
+func protocolFromCT(ct string) (protocol.Protocol, error) {
+ switch ct {
+ case "application/x-wireguard-profile":
+ return protocol.WireGuard, nil
+ case "application/x-wireguard+tcp-profile":
+ return protocol.WireGuardTCP, nil
+ case "application/x-openvpn-profile":
+ return protocol.OpenVPN, nil
+ }
+ return protocol.Unknown, fmt.Errorf("invalid content type: %s", ct)
+}
+
+// Connect sends a /connect to an eduVPN server
+// `ctx` is the context used for cancellation
+// protos is the list of protocols supported and wanted by the client
+func (a *API) Connect(ctx context.Context, prof profiles.Profile, protos []protocol.Protocol, pTCP bool) (*ConnectData, error) {
+ hdrs := http.Header{
+ "content-type": {"application/x-www-form-urlencoded"},
+ }
+ uv := url.Values{
+ "profile_id": {prof.ID},
+ }
+
+ if len(protos) == 0 {
+ return nil, errors.New("no protocols supplied")
+ }
+
+ var wgKey *wgtypes.Key
+
+ // Loop over the protocols and set the correct headers and values
+ for _, p := range protos {
+ switch p {
+ case protocol.WireGuard:
+ gk, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ wgKey = &gk
+ // Set the public key
+ pubkey := wgKey.PublicKey()
+ uv.Set("public_key", pubkey.String())
+ if !pTCP {
+ hdrs.Add("accept", "application/x-wireguard-profile")
+ }
+ hdrs.Add("accept", "application/x-wireguard+tcp-profile")
+ case protocol.OpenVPN:
+ hdrs.Add("accept", "application/x-openvpn-profile")
+ default:
+ return nil, errors.New("unknown protocol supplied")
+ }
+ }
+ // set prefer TCP
+ uv.Set("prefer_tcp", boolToYesNo(pTCP))
+
+ // Construct the parameters
+ params := &httpw.OptionalParams{Headers: hdrs, Body: uv}
+ h, body, err := a.authorizedRetry(ctx, http.MethodPost, "/connect", params)
+ if err != nil {
+ return nil, fmt.Errorf("failed API /connect call: %v", err)
+ }
+
+ // Parse expiry
+ expH := h.Get("expires")
+ expT, err := http.ParseTime(expH)
+ if err != nil {
+ return nil, fmt.Errorf("failed parsing expiry time: %v", err)
+ }
+
+ vpnCfg := string(body)
+ // Parse content type
+ contentH := h.Get("content-type")
+ proto, err := protocolFromCT(contentH)
+ if err != nil {
+ return nil, err
+ }
+
+ if proto == protocol.OpenVPN {
+ // ensure scripts are not ran by default by append script-security 0 to the config
+ vpnCfg += "\nscript-security 0"
+ return &ConnectData{
+ Configuration: vpnCfg,
+ Protocol: proto,
+ Expires: expT,
+ }, nil
+ }
+
+ vpnCfg, proxy, err := wireguard.Config(vpnCfg, wgKey, proto == protocol.WireGuardTCP)
+ if err != nil {
+ return nil, err
+ }
+ return &ConnectData{
+ Configuration: vpnCfg,
+ Protocol: proto,
+ Expires: expT,
+ Proxy: proxy,
+ }, nil
+}
+
+func getEndpoints(ctx context.Context, url string) (*endpoints.Endpoints, error) {
+ uStr, err := httpw.JoinURLPath(url, "/.well-known/vpn-user-portal")
+ if err != nil {
+ return nil, err
+ }
+ httpC := httpw.NewClient(nil)
+ _, body, err := httpC.Get(ctx, uStr)
+ if err != nil {
+ return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
+ }
+
+ ep := endpoints.Endpoints{}
+ if err = json.Unmarshal(body, &ep); err != nil {
+ return nil, fmt.Errorf("failed getting server endpoints with error: %w", err)
+ }
+ err = ep.Validate()
+ if err != nil {
+ return nil, err
+ }
+ return &ep, nil
+}
+
+func refreshEndpoints(ctx context.Context, sd ServerData) (*endpoints.List, *endpoints.List, error) {
+ // Get the endpoints
+ ep, err := getEndpoints(ctx, sd.BaseWK)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // This is a mess but we essentially have to instantiate different endpoints if the authorization base URL is different from the base portal URL
+ // This happens with secure internet when the location is not equal to the home location
+ var epauth *endpoints.Endpoints
+ if sd.BaseAuthWK != sd.BaseWK {
+ oep, err := getEndpoints(ctx, sd.BaseAuthWK)
+ if err != nil {
+ return nil, nil, err
+ }
+ epauth = oep
+ } else {
+ epauth = ep
+ }
+ return &ep.API.V3, &epauth.API.V3, err
+}
+
+type OAuthLogger struct{}
+
+func (ol *OAuthLogger) Logf(msg string, params ...interface{}) {
+ log.Logger.Debugf(msg, params...)
+}
+
+func (ol *OAuthLogger) Log(msg string) {
+ log.Logger.Debugf("%s", msg)
+}
+
+func init() {
+ eduoauth.UpdateLogger(&OAuthLogger{})
+}
diff --git a/internal/api/endpoints/endpoints.go b/internal/api/endpoints/endpoints.go
new file mode 100644
index 0000000..11e244b
--- /dev/null
+++ b/internal/api/endpoints/endpoints.go
@@ -0,0 +1,51 @@
+package endpoints
+
+import (
+ "fmt"
+ "net/url"
+)
+
+type List struct {
+ API string `json:"api_endpoint"`
+ Authorization string `json:"authorization_endpoint"`
+ Token string `json:"token_endpoint"`
+}
+
+type Versions struct {
+ V2 List `json:"http://eduvpn.org/api#2"`
+ V3 List `json:"http://eduvpn.org/api#3"`
+}
+
+// Endpoints defines the json format for /.well-known/vpn-user-portal".
+type Endpoints struct {
+ API Versions `json:"api"`
+ V string `json:"v"`
+}
+
+// Validate validates the endpoints by parsing them and checking the scheme is HTTP
+// An error is returned if they are not valid
+func (e Endpoints) Validate() error {
+ v3 := e.API.V3
+ pAPI, err := url.Parse(v3.API)
+ if err != nil {
+ return fmt.Errorf("failed to parse API endpoint: %w", err)
+ }
+ pAuth, err := url.Parse(v3.Authorization)
+ if err != nil {
+ return fmt.Errorf("failed to parse API authorization endpoint: %w", err)
+ }
+ pToken, err := url.Parse(v3.Token)
+ if err != nil {
+ return fmt.Errorf("failed to parse API token endpoint: %w", err)
+ }
+ if pAPI.Scheme != "https" {
+ return fmt.Errorf("API Scheme: '%s', is not equal to HTTPS", pAPI.Scheme)
+ }
+ if pAPI.Scheme != pAuth.Scheme {
+ return fmt.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
+ }
+ if pAPI.Scheme != pToken.Scheme {
+ return fmt.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
+ }
+ return nil
+}
diff --git a/internal/api/profiles/profiles.go b/internal/api/profiles/profiles.go
new file mode 100644
index 0000000..2f4fed7
--- /dev/null
+++ b/internal/api/profiles/profiles.go
@@ -0,0 +1,82 @@
+package profiles
+
+import (
+ "github.com/eduvpn/eduvpn-common/types/protocol"
+ "github.com/eduvpn/eduvpn-common/types/server"
+)
+
+type Profile struct {
+ ID string `json:"profile_id"`
+ DisplayName string `json:"display_name"`
+ VPNProtoList []string `json:"vpn_proto_list"`
+ DefaultGateway bool `json:"default_gateway"`
+ DNSSearchDomains []string `json:"dns_search_domain_list"`
+}
+
+type ListInfo struct {
+ ProfileList []Profile `json:"profile_list"`
+}
+
+type Info struct {
+ Info ListInfo `json:"info"`
+}
+
+func (i Info) Len() int {
+ return len(i.Info.ProfileList)
+}
+
+func (i Info) Get(id string) *Profile {
+ for _, p := range i.Info.ProfileList {
+ if p.ID == id {
+ return &p
+ }
+ }
+ return nil
+}
+
+func (i Info) MustIndex(n int) Profile {
+ return i.Info.ProfileList[n]
+}
+
+func hasProtocol(protos []string, proto protocol.Protocol) bool {
+ for _, p := range protos {
+ if protocol.New(p) == proto {
+ return true
+ }
+ }
+ return false
+}
+
+func (p *Profile) HasOpenVPN() bool {
+ return hasProtocol(p.VPNProtoList, protocol.OpenVPN)
+}
+
+func (p *Profile) HasWireGuard() bool {
+ return hasProtocol(p.VPNProtoList, protocol.WireGuard)
+}
+
+func (i Info) FilterWireGuard() *Info {
+ var ret []Profile
+ for _, p := range i.Info.ProfileList {
+ if !p.HasOpenVPN() {
+ continue
+ }
+ }
+ return &Info{
+ Info: ListInfo{
+ ProfileList: ret,
+ },
+ }
+}
+
+func (i Info) Public() server.Profiles {
+ m := make(map[string]server.Profile)
+ for _, p := range i.Info.ProfileList {
+ m[p.ID] = server.Profile{
+ DisplayName: map[string]string{
+ "en": p.DisplayName,
+ },
+ }
+ }
+ return server.Profiles{Map: m}
+}
diff --git a/internal/api/redirect.go b/internal/api/redirect.go
new file mode 100644
index 0000000..5d9e749
--- /dev/null
+++ b/internal/api/redirect.go
@@ -0,0 +1,23 @@
+package api
+
+// customRedirects supplies redirect values that should be handled by the app itself
+// here we hardcode the redirect values that we should use in the OAuth requests
+// these values were taken from https://git.sr.ht/~fkooman/vpn-user-portal/tree/v3/item/src/OAuth/VpnClientDb.php
+var customRedirects = map[string]string{
+ "org.letsconnect-vpn.app.ios": "org.letsconnect-vpn.app.ios:/api/callback",
+ "org.letsconnect-vpn.app.android": "org.letsconnect-vpn.app:/api/callback",
+ "org.eduvpn.app.ios": "org.eduvpn.app.ios:/api/callback",
+ "org.eduvpn.app.android": "org.eduvpn.app:/api/callback",
+}
+
+// customRedirect returns the custom redirect string for the clientID `cid`
+// Empty string if none is defined or one is defined but is empty.
+// In both empty string cases, eduvpn-common handles the redirects as 127.0.0.1 local server redirects
+// If a non-empty string is returned, the redirect should be handled by the client and we only use the redirect URI value in our OAuth requests
+func customRedirect(cid string) string {
+ v, ok := customRedirects[cid]
+ if !ok {
+ return ""
+ }
+ return v
+}
diff --git a/internal/server/custom.go b/internal/server/custom.go
new file mode 100644
index 0000000..b4b81cb
--- /dev/null
+++ b/internal/server/custom.go
@@ -0,0 +1,63 @@
+package server
+
+import (
+ "context"
+ "time"
+
+ "github.com/eduvpn/eduvpn-common/internal/api"
+ "github.com/eduvpn/eduvpn-common/internal/config/v2"
+ "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/jwijenbergh/eduoauth-go"
+)
+
+func (s *Servers) AddCustom(ctx context.Context, id string, na bool) (*Server, error) {
+ sd := api.ServerData{
+ ID: id,
+ Type: server.TypeCustom,
+ BaseWK: id,
+ BaseAuthWK: id,
+ }
+
+ var a *api.API
+ var err error
+ if !na {
+ // Authorize by creating the API object
+ a, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ err = s.config.AddServer(id, server.TypeCustom, v2.Server{LastAuthorizeTime: time.Now()})
+ if err != nil {
+ return nil, err
+ }
+
+ cust := s.NewServer(id, server.TypeCustom, a)
+ // Return the server with the API
+
+ return &cust, nil
+}
+
+func (s *Servers) GetCustom(ctx context.Context, id string, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+ sd := api.ServerData{
+ ID: id,
+ Type: server.TypeCustom,
+ BaseWK: id,
+ BaseAuthWK: id,
+ DisableAuthorize: disableAuth,
+ }
+
+ // Get the server from the config
+ _, err := s.config.GetServer(id, server.TypeCustom)
+ if err != nil {
+ return nil, err
+ }
+ a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok)
+ if err != nil {
+ return nil, err
+ }
+
+ cust := s.NewServer(id, server.TypeCustom, a)
+ return &cust, nil
+}
diff --git a/internal/server/institute.go b/internal/server/institute.go
new file mode 100644
index 0000000..881f96d
--- /dev/null
+++ b/internal/server/institute.go
@@ -0,0 +1,73 @@
+package server
+
+import (
+ "context"
+ "time"
+
+ "github.com/eduvpn/eduvpn-common/internal/api"
+ "github.com/eduvpn/eduvpn-common/internal/config/v2"
+ "github.com/eduvpn/eduvpn-common/internal/discovery"
+ "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/jwijenbergh/eduoauth-go"
+)
+
+func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, na bool) (*Server, error) {
+ // This is basically done to double check if the server is part of the institute access section of disco
+ dsrv, err := disco.ServerByURL(id, "institute_access")
+ if err != nil {
+ return nil, err
+ }
+
+ sd := api.ServerData{
+ ID: dsrv.BaseURL,
+ Type: server.TypeInstituteAccess,
+ BaseWK: dsrv.BaseURL,
+ BaseAuthWK: dsrv.BaseURL,
+ }
+
+ var a *api.API
+ if !na {
+ // Authorize by creating the API object
+ a, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ err = s.config.AddServer(dsrv.BaseURL, server.TypeInstituteAccess, v2.Server{LastAuthorizeTime: time.Now()})
+ if err != nil {
+ return nil, err
+ }
+
+ inst := s.NewServer(dsrv.BaseURL, server.TypeInstituteAccess, a)
+ return &inst, nil
+}
+
+func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+ // This is basically done to double check if the server is part of the institute access section of disco
+ dsrv, err := disco.ServerByURL(id, "institute_access")
+ if err != nil {
+ return nil, err
+ }
+
+ // Get the server from the config
+ _, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess)
+ if err != nil {
+ return nil, err
+ }
+ sd := api.ServerData{
+ ID: dsrv.BaseURL,
+ Type: server.TypeInstituteAccess,
+ BaseWK: dsrv.BaseURL,
+ BaseAuthWK: dsrv.BaseURL,
+ DisableAuthorize: disableAuth,
+ }
+ // Authorize by creating the API object
+ a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok)
+ if err != nil {
+ return nil, err
+ }
+
+ inst := s.NewServer(dsrv.BaseURL, server.TypeInstituteAccess, a)
+ return &inst, nil
+}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
new file mode 100644
index 0000000..19e75a1
--- /dev/null
+++ b/internal/server/secureinternet.go
@@ -0,0 +1,91 @@
+package server
+
+import (
+ "context"
+ "errors"
+ "time"
+
+ "github.com/eduvpn/eduvpn-common/internal/api"
+ "github.com/eduvpn/eduvpn-common/internal/config/v2"
+ "github.com/eduvpn/eduvpn-common/internal/discovery"
+ "github.com/eduvpn/eduvpn-common/internal/util"
+ "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/jwijenbergh/eduoauth-go"
+)
+
+func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, na bool) (*Server, error) {
+ if s.config.HasSecureInternet() {
+ return nil, errors.New("a secure internet server already exists")
+ }
+ dorg, dsrv, err := disco.SecureHomeArgs(orgID)
+ if err != nil {
+ // We mark the organizations as expired because we got an error
+ // Note that in the docs it states that it only should happen when the Org ID doesn't exist
+ // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found
+ disco.MarkOrganizationsExpired()
+ return nil, err
+ }
+
+ sd := api.ServerData{
+ ID: orgID,
+ Type: server.TypeSecureInternet,
+ BaseWK: dsrv.BaseURL,
+ BaseAuthWK: dsrv.BaseURL,
+ ProcessAuth: func(url string) string {
+ return util.ReplaceWAYF(dsrv.AuthenticationURLTemplate, url, dorg.OrgID)
+ },
+ }
+
+ var a *api.API
+ if !na {
+ // Authorize by creating the API object
+ a, err = api.NewAPI(ctx, s.clientID, sd, s.cb, nil)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ err = s.config.AddServer(orgID, server.TypeSecureInternet, v2.Server{CountryCode: dsrv.CountryCode, LastAuthorizeTime: time.Now()})
+ if err != nil {
+ return nil, err
+ }
+
+ sec := s.NewServer(orgID, server.TypeSecureInternet, a)
+ return &sec, nil
+}
+
+func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+ srv, err := s.config.GetServer(orgID, server.TypeSecureInternet)
+ if err != nil {
+ return nil, err
+ }
+
+ dorg, dhome, err := disco.SecureHomeArgs(orgID)
+ if err != nil {
+ return nil, err
+ }
+
+ dloc, err := disco.ServerByCountryCode(srv.CountryCode)
+ if err != nil {
+ return nil, err
+ }
+
+ sd := api.ServerData{
+ ID: dorg.OrgID,
+ Type: server.TypeSecureInternet,
+ BaseWK: dloc.BaseURL,
+ BaseAuthWK: dhome.BaseURL,
+ ProcessAuth: func(url string) string {
+ return util.ReplaceWAYF(dhome.AuthenticationURLTemplate, url, dorg.OrgID)
+ },
+ DisableAuthorize: disableAuth,
+ }
+
+ a, err := api.NewAPI(ctx, s.clientID, sd, s.cb, tok)
+ if err != nil {
+ return nil, err
+ }
+
+ sec := s.NewServer(orgID, server.TypeSecureInternet, a)
+ return &sec, nil
+}
diff --git a/internal/server/server.go b/internal/server/server.go
index 1bdef28..97dafff 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -2,280 +2,227 @@ package server
import (
"context"
+ "errors"
"os"
"time"
- "github.com/eduvpn/eduvpn-common/internal/discovery"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/internal/server/api"
- "github.com/eduvpn/eduvpn-common/internal/server/base"
- "github.com/eduvpn/eduvpn-common/internal/server/profile"
- "github.com/eduvpn/eduvpn-common/internal/wireguard"
+ "github.com/eduvpn/eduvpn-common/internal/api"
+ "github.com/eduvpn/eduvpn-common/internal/api/profiles"
+ v2 "github.com/eduvpn/eduvpn-common/internal/config/v2"
"github.com/eduvpn/eduvpn-common/types/protocol"
srvtypes "github.com/eduvpn/eduvpn-common/types/server"
- "github.com/go-errors/errors"
)
-type Server interface {
- // OAuth returns the struct used for OAuth
- OAuth() *oauth.OAuth
-
- // TemplateAuth returns the authorization URL template function
- TemplateAuth() func(string) string
-
- // Base returns the server base
- Base() (*base.Base, error)
-
- // NeedsLocation checks if the server needs a secure internet location
- NeedsLocation() bool
-
- // Public returns the representation that will be passed over the CGO barrier
- Public() (interface{}, error)
-
- // RefreshEndpoints refreshes the endpoints for the server
- RefreshEndpoints(context.Context, *discovery.Discovery) error
+type Server struct {
+ identifier string
+ t srvtypes.Type
+ apiw *api.API
+ storage *v2.V2
}
-// Name gets the name for the server and falls back to a default of "Unknown Server"
-func Name(srv Server) string {
- n := "Unknown Server"
- if b, err := srv.Base(); err == nil {
- n = b.URL
- }
- return n
-}
-
-func UpdateTokens(srv Server, t oauth.Token) {
- srv.OAuth().UpdateTokens(t)
-}
-
-func OAuthURL(srv Server, name string, cr string) (string, error) {
- return srv.OAuth().AuthURL(name, srv.TemplateAuth(), cr)
-}
+var ErrInvalidProfile = errors.New("invalid profile")
-func OAuthExchange(ctx context.Context, srv Server, uri string) error {
- return srv.OAuth().Exchange(ctx, uri)
-}
-
-func HeaderToken(ctx context.Context, srv Server) (string, error) {
- return srv.OAuth().AccessToken(ctx)
-}
-
-func MarkTokenExpired(srv Server) {
- srv.OAuth().SetTokenExpired()
-}
-
-func MarkTokensForRenew(srv Server) {
- srv.OAuth().SetTokenRenew()
-}
-
-func NeedsRelogin(ctx context.Context, srv Server) bool {
- // TODO: this error can be a context cancel
- _, err := HeaderToken(ctx, srv)
- return err != nil
+func (s *Servers) NewServer(identifier string, t srvtypes.Type, api *api.API) Server {
+ return Server{
+ identifier: identifier,
+ t: t,
+ apiw: api,
+ storage: s.config,
+ }
}
-func CurrentProfile(srv Server) (*profile.Profile, error) {
- b, err := srv.Base()
+// Profiles gets the profiles for the server
+// It only does a /info network request if the profiles have not been cached
+// force indicates whether or not the profiles should be fetched fresh
+func (s *Server) Profiles(ctx context.Context) (*profiles.Info, error) {
+ a, err := s.api()
if err != nil {
return nil, err
}
- pID := b.Profiles.Current
- for _, profile := range b.Profiles.Info.ProfileList {
- if profile.ID == pID {
- return &profile, nil
- }
+ // Otherwise get fresh profiles and set the cache
+ prfs, err := a.Info(ctx)
+ if err != nil {
+ return nil, err
}
-
- return nil, errors.Errorf("profile not found: " + pID)
-}
-
-func ValidProfiles(srv Server, wireguardSupport bool) (*profile.Info, error) {
- // No error wrapping here otherwise we wrap it too much
- b, err := srv.Base()
+ err = s.SetProfileList(prfs.Public())
if err != nil {
return nil, err
}
- ps := b.Profiles.Supported(wireguardSupport)
- if len(ps) == 0 {
- return nil, errors.Errorf("no profiles found with supported protocols")
+ return prfs, nil
+}
+
+func (s *Server) api() (*api.API, error) {
+ if s.apiw == nil {
+ return nil, errors.New("no API object found")
}
- return &profile.Info{
- Current: b.Profiles.Current,
- Info: profile.ListInfo{
- ProfileList: ps,
- },
- }, nil
+ return s.apiw, nil
}
-func Profile(srv Server, id string) error {
- b, err := srv.Base()
+func (s *Server) findProfile(ctx context.Context, wgSupport bool) (*profiles.Profile, error) {
+ // Get the profiles by ignoring the cache
+ prfs, err := s.Profiles(ctx)
if err != nil {
- return err
+ return nil, err
}
- if !b.Profiles.Has(id) {
- return errors.Errorf("no profile available with id: %s", id)
+
+ // No profiles available
+ if prfs.Len() == 0 {
+ return nil, errors.New("the server has no available profiles for your account")
}
- b.Profiles.Current = id
- return nil
-}
-type ConfigData struct {
- // The configuration
- Config string
+ // No WireGuard support, we have to filter the profiles that only have WireGuard
+ if !wgSupport {
+ prfs = prfs.FilterWireGuard()
+ }
- // The type of configuration
- Type string
-}
+ var chosenP profiles.Profile
-// Public gets the public data from the types package
-// dg specifies if this config is default gateway
-func (c *ConfigData) Public(dg bool) srvtypes.Configuration {
- return srvtypes.Configuration{
- VPNConfig: c.Config,
- Protocol: protocol.New(c.Type),
- DefaultGateway: dg,
+ n := prfs.Len()
+ switch n {
+ // If we now get no profiles then that means a profile with only WireGuard was removed
+ case 0:
+ return nil, errors.New("the server has only WireGuard profiles but the client does not support WireGuard")
+ case 1:
+ // Only one profile, make sure it is set
+ chosenP = prfs.MustIndex(0)
+ default:
+ // Profile doesn't exist
+ prID, err := s.ProfileID()
+ if err != nil {
+ return nil, err
+ }
+ v := prfs.Get(prID)
+ if v == nil {
+ return nil, ErrInvalidProfile
+ }
+ chosenP = *v
}
+ return &chosenP, nil
}
-func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
- b, err := srv.Base()
+func (s *Server) connect(ctx context.Context, wgSupport bool, pTCP bool) (*srvtypes.Configuration, error) {
+ a, err := s.api()
if err != nil {
return nil, err
}
- pID := b.Profiles.Current
- key, err := wireguard.GenerateKey()
+ // find a suitable profile to connect
+ chosenP, err := s.findProfile(ctx, wgSupport)
if err != nil {
return nil, err
}
-
- pub := key.PublicKey().String()
- cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport)
+ err = s.SetProfileID(chosenP.ID)
if err != nil {
return nil, err
}
- // Store start and end time
- b.StartTime = time.Now()
- b.EndTime = exp
-
- if proto == "wireguard" {
- // This needs the go code a way to identify a connection
- // Use the uuid of the connection e.g. on Linux
- // This needs the client code to call the go code
-
- cfg = wireguard.ConfigAddKey(cfg, key)
+ protos := []protocol.Protocol{protocol.OpenVPN}
+ if wgSupport {
+ protos = append(protos, protocol.WireGuard)
}
-
- return &ConfigData{Config: cfg, Type: proto}, nil
-}
-
-func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) {
- b, err := srv.Base()
+ // If the client supports WireGuard and the profile supports both protocols we remove openvpn from client support if EDUVPN_PREFER_WG is set to "1"
+ // This also only happens if prefer TCP is set to false
+ // TODO: remove the prefer TCP check when we have implemented proxyguard
+ if wgSupport && os.Getenv("EDUVPN_PREFER_WG") == "1" {
+ if chosenP.HasWireGuard() && chosenP.HasOpenVPN() {
+ protos = []protocol.Protocol{protocol.WireGuard}
+ }
+ }
+ // SAFETY: chosenP is guaranteed to be non-nil
+ apicfg, err := a.Connect(ctx, *chosenP, protos, pTCP)
if err != nil {
return nil, err
}
- pid := b.Profiles.Current
- cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP)
+ err = s.SetExpireTime(apicfg.Expires)
if err != nil {
return nil, err
}
-
- // Store start and end time
- b.StartTime = time.Now()
- b.EndTime = exp
-
- return &ConfigData{Config: cfg, Type: "openvpn"}, nil
+ var proxy *srvtypes.Proxy
+ if apicfg.Proxy != nil {
+ proxy = &srvtypes.Proxy{
+ SourcePort: apicfg.Proxy.SourcePort,
+ Listen: apicfg.Proxy.Listen,
+ Peer: apicfg.Proxy.Peer,
+ }
+ }
+ return &srvtypes.Configuration{
+ VPNConfig: apicfg.Configuration,
+ Protocol: apicfg.Protocol,
+ DefaultGateway: chosenP.DefaultGateway,
+ DNSSearchDomains: chosenP.DNSSearchDomains,
+ Proxy: proxy,
+ }, nil
}
-func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) {
- b, err := srv.Base()
+func (s *Server) Disconnect(ctx context.Context) error {
+ a, err := s.api()
if err != nil {
- return false, err
- }
- // Get new profiles using the info call
- // This does not override the current profile
- err = api.Info(ctx, b, srv.OAuth())
- if err != nil {
- return false, err
+ return err
}
+ return a.Disconnect(ctx)
+}
- // If there was a profile chosen and it doesn't exist anymore, reset it
- if b.Profiles.Current != "" {
- if _, err = CurrentProfile(srv); err != nil {
- b.Profiles.Current = ""
- }
+func (s *Server) cfgServer() (*v2.Server, error) {
+ if s.storage == nil {
+ return nil, errors.New("cannot get server, no configuration passed")
}
+ return s.storage.GetServer(s.identifier, s.t)
+}
- // there are multiple profiles and no selection has been made
- if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" {
- return false, nil
+func (s *Server) SetProfileID(id string) error {
+ cs, err := s.cfgServer()
+ if err != nil {
+ return err
}
+ cs.Profiles.Current = id
+ return nil
+}
- // Set the current profile if there is only one profile or profile is already selected
- // Set the first profile if none is selected
- if b.Profiles.Current == "" {
- b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID
- }
- p, err := CurrentProfile(srv)
- // shouldn't happen
+func (s *Server) SetProfileList(prfs srvtypes.Profiles) error {
+ cs, err := s.cfgServer()
if err != nil {
- return false, err
- }
- // Profile does not support OpenVPN but the client also doesn't support WireGuard
- if !p.SupportsOpenVPN() && !wireguardSupport {
- return false, nil
+ return err
}
- return true, nil
+ cs.Profiles.Map = prfs.Map
+ return nil
}
-func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
- p, err := CurrentProfile(server)
+func (s *Server) SetExpireTime(et time.Time) error {
+ cs, err := s.cfgServer()
if err != nil {
- return nil, err
+ return err
}
+ cs.ExpireTime = et
+ return nil
+}
- ovpn := p.SupportsOpenVPN()
- wg := p.SupportsWireguard() && wireguardSupport
-
- // If we don't prefer TCP and this profile and client supports wireguard,
- // we disable openvpn if the EDUVPN_PREFER_WG environment variable is set
- // This is useful to force WireGuard if the profile supports both OpenVPN and WireGuard but the server still prefers OpenVPN
- if !preferTCP && wg {
- if os.Getenv("EDUVPN_PREFER_WG") == "1" {
- ovpn = false
- }
+func (s *Server) ProfileID() (string, error) {
+ cs, err := s.cfgServer()
+ if err != nil {
+ return "", err
}
+ return cs.Profiles.Current, nil
+}
- var cfg *ConfigData
-
- switch {
- // The config supports wireguard and optionally openvpn
- case wg:
- // A wireguard connect call needs to generate a wireguard key and add it to the config
- // Also the server could send back an OpenVPN config if it supports OpenVPN
- cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn)
- // The config only supports OpenVPN
- case ovpn:
- cfg, err = openVPNGetConfig(ctx, server, preferTCP)
- // The config supports no available protocol because the profile only supports WireGuard but the client doesn't
- default:
- return nil, errors.New("no supported protocol found")
+func (s *Server) SetLocation(loc string) error {
+ if s.t != srvtypes.TypeSecureInternet {
+ return errors.New("changing secure internet location is only possible when the server is a secure location")
}
-
- // Add script security 0 to disable OpenVPN scripts
- // The client may override this but we provide the default protection here
- if err == nil && cfg.Type == "openvpn" {
- cfg.Config += "\nscript-security 0"
+ cs, err := s.cfgServer()
+ if err != nil {
+ return err
}
- return cfg, err
+ cs.CountryCode = loc
+ return nil
}
-func Disconnect(ctx context.Context, server Server) error {
- b, err := server.Base()
- if err != nil {
- return err
+func (s *Server) SetCurrent() error {
+ if s.storage == nil {
+ return errors.New("no storage available")
}
- return api.Disconnect(ctx, b, server.OAuth())
+ s.storage.LastChosen = &v2.ServerType{
+ ID: s.identifier,
+ T: s.t,
+ }
+ return nil
}
diff --git a/internal/server/servers.go b/internal/server/servers.go
new file mode 100644
index 0000000..fe2550c
--- /dev/null
+++ b/internal/server/servers.go
@@ -0,0 +1,113 @@
+package server
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/eduvpn/eduvpn-common/internal/api"
+ "github.com/eduvpn/eduvpn-common/internal/config/v2"
+ "github.com/eduvpn/eduvpn-common/internal/discovery"
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/jwijenbergh/eduoauth-go"
+)
+
+type Callbacks interface {
+ api.Callbacks
+ GettingConfig() error
+ InvalidProfile(context.Context, *Server) (string, error)
+}
+
+type Servers struct {
+ clientID string
+ cb Callbacks
+ WGSupport bool
+ config *v2.V2
+}
+
+func (s *Servers) Remove(identifier string, t srvtypes.Type) error {
+ return s.config.RemoveServer(identifier, t)
+}
+
+func NewServers(name string, cb Callbacks, wgSupport bool, cfg *v2.V2) Servers {
+ return Servers{
+ clientID: name,
+ cb: cb,
+ WGSupport: wgSupport,
+ config: cfg,
+ }
+}
+
+type CurrentServer struct {
+ *v2.Server
+ T v2.ServerType
+ srvs *Servers
+}
+
+func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) {
+ switch cs.T.T {
+ case srvtypes.TypeInstituteAccess:
+ return cs.srvs.GetInstitute(ctx, cs.T.ID, disco, tokens, disableAuth)
+ case srvtypes.TypeSecureInternet:
+ return cs.srvs.GetSecure(ctx, cs.T.ID, disco, tokens, disableAuth)
+ case srvtypes.TypeCustom:
+ return cs.srvs.GetCustom(ctx, cs.T.ID, tokens, disableAuth)
+ default:
+ return nil, fmt.Errorf("no such server type: %d", cs.T.T)
+ }
+}
+
+func (s *Servers) GetServer(id string, t srvtypes.Type) (*v2.Server, error) {
+ if s.config == nil {
+ return nil, errors.New("no configuration available")
+ }
+ return s.config.GetServer(id, t)
+}
+
+func (s *Servers) CurrentServer() (*CurrentServer, error) {
+ curr, k, err := s.config.CurrentServer()
+ if err != nil {
+ return nil, err
+ }
+ return &CurrentServer{
+ Server: curr,
+ T: *k,
+ srvs: s,
+ }, nil
+}
+
+func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) {
+ return s.config.PublicCurrent(disco)
+}
+
+func (s *Servers) ConnectWithCallbacks(ctx context.Context, srv *Server, pTCP bool) (*srvtypes.Configuration, error) {
+ err := srv.SetCurrent()
+ if err != nil {
+ return nil, err
+ }
+ err = s.cb.GettingConfig()
+ if err != nil {
+ return nil, err
+ }
+ cfg, err := srv.connect(ctx, s.WGSupport, pTCP)
+ if err == nil {
+ return cfg, nil
+ }
+ if !errors.Is(err, ErrInvalidProfile) {
+ return cfg, err
+ }
+ // Get a new profile from the callback
+ pr, err := s.cb.InvalidProfile(ctx, srv)
+ if err != nil {
+ return cfg, err
+ }
+ err = srv.SetProfileID(pr)
+ if err != nil {
+ return nil, err
+ }
+ err = s.cb.GettingConfig()
+ if err != nil {
+ return nil, err
+ }
+ return srv.connect(ctx, s.WGSupport, pTCP)
+}
diff --git a/internal/server/time.go b/internal/server/time.go
new file mode 100644
index 0000000..4e54ef6
--- /dev/null
+++ b/internal/server/time.go
@@ -0,0 +1,70 @@
+package server
+
+import "time"
+
+// RenewButtonTime returns the time when the renew button should be shown for the server
+// Implemented according to: https://github.com/eduvpn/documentation/blob/cdf4d054f7652d74e4192494e8bb0e21040e46ac/API.md#session-expiry
+func RenewButtonTime(st time.Time, et time.Time) int64 {
+ d := et.Sub(st)
+
+ // If the time is less than 24 hours (a day left), show it when 30 minutes have passed or on expired if less than 30 minutes
+ dayl := time.Duration(24 * time.Hour)
+ if d < dayl {
+ // Get the minimum time to add, 30 minutes or on expired
+ m := time.Duration(30 * time.Minute)
+ // The total delta time is larger, return that we should show the button after 30 minutes
+ if d > m {
+ return st.Add(30 * time.Minute).Unix()
+ }
+ // Just show it on expired
+ return st.Add(d).Unix()
+ }
+
+ // Else just show it when 24 hours is left
+ // This is the delta minus 24 hours left as that's how long it takes for a day to be left in the expiry
+ // We thus add this to the start time
+ tillDay := d - dayl
+ t := st.Add(tillDay)
+ return t.Unix()
+}
+
+func CountdownTime(st time.Time, et time.Time) int64 {
+ d := et.Sub(st)
+
+ dayl := time.Duration(24 * time.Hour)
+
+ // This is just the last 24 hours
+ // if less than or equal to 24 hours, immediately
+ if d <= dayl {
+ return st.Unix()
+ }
+
+ tillDay := d - dayl
+ t := st.Add(tillDay)
+ return t.Unix()
+}
+
+func NotificationTimes(st time.Time, et time.Time) []int64 {
+ last := []time.Duration{
+ time.Duration(0),
+ time.Duration(1 * time.Hour),
+ time.Duration(2 * time.Hour),
+ time.Duration(4 * time.Hour),
+ }
+
+ var t []int64
+
+ d := et.Sub(st)
+ for _, l := range last {
+ // If the notification remaining time is more than the total delta, continue
+ if l > d {
+ continue
+ }
+ // calculating the time till a notification must happen
+ tillN := d - l
+ // Get absolute time when this notification must be shown by adding the delta
+ c := st.Add(tillN)
+ t = append(t, c.Unix())
+ }
+ return t
+}