diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-04-12 22:52:49 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | a23c3e61c5d89ef67973891b5b3a176c06e1b174 (patch) | |
| tree | f1eed03b047f8affd3d5123fa5c9e868ac7d8bec /internal | |
| parent | ee95eb45708e1fa766a63866d26d05d13f23e8c9 (diff) | |
Refactor: Split internal server into multiple packages
- Pass contexts
- Have separate packages for e.g. custom, institute and secure
- internet servers, profiles....
- Return types from the public ./types package with a Public() method
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/server/api/api.go (renamed from internal/server/api.go) | 120 | ||||
| -rw-r--r-- | internal/server/api/api_test.go (renamed from internal/server/api_test.go) | 38 | ||||
| -rw-r--r-- | internal/server/base/base.go (renamed from internal/server/base.go) | 48 | ||||
| -rw-r--r-- | internal/server/custom.go | 35 | ||||
| -rw-r--r-- | internal/server/custom/custom.go | 31 | ||||
| -rw-r--r-- | internal/server/endpoints/endpoints.go | 53 | ||||
| -rw-r--r-- | internal/server/institute/institute.go | 106 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 114 | ||||
| -rw-r--r-- | internal/server/list.go | 179 | ||||
| -rw-r--r-- | internal/server/profile.go | 44 | ||||
| -rw-r--r-- | internal/server/profile/profile.go | 88 | ||||
| -rw-r--r-- | internal/server/profile/profile_test.go (renamed from internal/server/profile_test.go) | 6 | ||||
| -rw-r--r-- | internal/server/secure/secure.go | 148 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 175 | ||||
| -rw-r--r-- | internal/server/server.go | 130 | ||||
| -rw-r--r-- | internal/server/servers.go | 121 |
16 files changed, 763 insertions, 673 deletions
diff --git a/internal/server/api.go b/internal/server/api/api.go index 546c02a..9ad6f2d 100644 --- a/internal/server/api.go +++ b/internal/server/api/api.go @@ -1,6 +1,7 @@ -package server +package api import ( + "context" "encoding/json" "fmt" "net/http" @@ -10,65 +11,43 @@ import ( httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/server/profile" "github.com/go-errors/errors" ) -func validateEndpoints(endpoints Endpoints) error { - v3 := endpoints.API.V3 - pAPI, err := url.Parse(v3.API) +func Endpoints(ctx context.Context, b *base.Base) error { + uStr, err := httpw.JoinURLPath(b.URL, "/.well-known/vpn-user-portal") if err != nil { - return errors.WrapPrefix(err, "failed to parse API endpoint", 0) - } - pAuth, err := url.Parse(v3.Authorization) - if err != nil { - return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0) - } - pToken, err := url.Parse(v3.Token) - if err != nil { - return errors.WrapPrefix(err, "failed to parse API token endpoint", 0) - } - if pAPI.Scheme != pAuth.Scheme { - return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) - } - if pAPI.Scheme != pToken.Scheme { - return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) - } - if pAPI.Host != pAuth.Host { - return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host) - } - if pAPI.Host != pToken.Host { - return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host) - } - return nil -} - -func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) { - uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal") - if err != nil { - return nil, err + return err } - if client == nil { - client = httpw.NewClient() + if b.HTTPClient == nil { + b.HTTPClient = httpw.NewClient() } - _, body, err := client.Get(uStr) + _, body, err := b.HTTPClient.Get(ctx, uStr) if err != nil { - return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) + return errors.WrapPrefix(err, "failed getting server endpoints", 0) } - ep := Endpoints{} + ep := endpoints.Endpoints{} if err = json.Unmarshal(body, &ep); err != nil { - return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) + return errors.WrapPrefix(err, "failed getting server endpoints", 0) } - err = validateEndpoints(ep) + err = ep.Validate() if err != nil { - return nil, err + return err } - return &ep, nil + b.Endpoints = ep + return nil } -func apiAuthorized( - srv Server, +func authorized( + ctx context.Context, + b *base.Base, + oauth *oauth.OAuth, method string, endpoint string, opts *httpw.OptionalParams, @@ -78,10 +57,6 @@ func apiAuthorized( opts = &httpw.OptionalParams{} } errorMessage := "failed API authorized" - b, err := srv.Base() - if err != nil { - return nil, nil, errors.WrapPrefix(err, errorMessage, 0) - } // Join the paths u, err := url.Parse(b.Endpoints.API.V3.API) @@ -91,7 +66,7 @@ func apiAuthorized( u.Path = path.Join(u.Path, endpoint) // Make sure the tokens are valid, this will return an error if re-login is needed - t, err := HeaderToken(srv) + t, err := oauth.AccessToken(ctx) if err != nil { return nil, nil, errors.WrapPrefix(err, errorMessage, 0) } @@ -105,19 +80,21 @@ func apiAuthorized( } // Create a client if it doesn't exist - if b.httpClient == nil { - b.httpClient = httpw.NewClient() + if b.HTTPClient == nil { + b.HTTPClient = httpw.NewClient() } - return b.httpClient.Do(method, u.String(), opts) + return b.HTTPClient.Do(ctx, method, u.String(), opts) } -func apiAuthorizedRetry( - srv Server, +func authorizedRetry( + ctx context.Context, + b *base.Base, + auth *oauth.OAuth, method string, endpoint string, opts *httpw.OptionalParams, ) (http.Header, []byte, error) { - h, body, err := apiAuthorized(srv, method, endpoint, opts) + h, body, err := authorized(ctx, b, auth, method, endpoint, opts) if err == nil { return h, body, nil } @@ -127,27 +104,22 @@ func apiAuthorizedRetry( if errors.As(err, &statErr) && statErr.Status == 401 { log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint) // Mark the token as expired and retry, so we trigger the refresh flow - MarkTokenExpired(srv) - h, body, err = apiAuthorized(srv, method, endpoint, opts) + auth.SetTokenExpired() + h, body, err = authorized(ctx, b, auth, method, endpoint, opts) } return h, body, err } -func APIInfo(srv Server) error { - _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil) +func Info(ctx context.Context, b *base.Base, auth *oauth.OAuth) error { + _, body, err := authorizedRetry(ctx, b, auth, http.MethodGet, "/info", nil) if err != nil { return err } - profiles := ProfileInfo{} + profiles := profile.Info{} if err = json.Unmarshal(body, &profiles); err != nil { return errors.WrapPrefix(err, "failed API /info", 0) } - b, err := srv.Base() - if err != nil { - return err - } - // Store the profiles and make sure that the current profile is not overwritten prev := b.Profiles.Current b.Profiles = profiles @@ -163,8 +135,10 @@ func boolToYesNo(preferTCP bool) string { return "no" } -func APIConnectWireguard( - srv Server, +func ConnectWireguard( + ctx context.Context, + b *base.Base, + auth *oauth.OAuth, profileID string, pubkey string, preferTCP bool, @@ -186,7 +160,7 @@ func APIConnectWireguard( "public_key": {pubkey}, "prefer_tcp": {boolToYesNo(preferTCP)}, } - h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", + h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect", &httpw.OptionalParams{Headers: hdrs, Body: vals}) if err != nil { return "", "", time.Time{}, err @@ -208,7 +182,7 @@ func APIConnectWireguard( return string(body), content, expTime, nil } -func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) { +func ConnectOpenVPN(ctx context.Context, b *base.Base, auth *oauth.OAuth, profileID string, preferTCP bool) (string, time.Time, error) { hdrs := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, @@ -219,7 +193,7 @@ func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, ti "prefer_tcp": {boolToYesNo(preferTCP)}, } - h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", + h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect", &httpw.OptionalParams{Headers: hdrs, Body: vals}) if err != nil { return "", time.Time{}, err @@ -234,10 +208,10 @@ func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, ti return string(body), expT, nil } -// APIDisconnect disconnects from the API. -func APIDisconnect(server Server) error { +// Disconnect disconnects the VPN using the API. +func Disconnect(ctx context.Context, b *base.Base, auth *oauth.OAuth) error { // The timeout is a bit lower here such that this does not take a too long time for disconnecting // Clients may wish to retry this - _, _, err := apiAuthorized(server, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) + _, _, err := authorized(ctx, b, auth, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) return err } diff --git a/internal/server/api_test.go b/internal/server/api/api_test.go index b1e3550..7509a30 100644 --- a/internal/server/api_test.go +++ b/internal/server/api/api_test.go @@ -1,11 +1,14 @@ -package server +package api import ( + "context" "encoding/json" "fmt" "net/http" "testing" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" "github.com/eduvpn/eduvpn-common/internal/test" "github.com/go-errors/errors" ) @@ -17,7 +20,7 @@ func getErrorMsg(err error) string { return err.Error() } -func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool { +func compareEndpoints(ep1 endpoints.Endpoints, ep2 endpoints.Endpoints) bool { v3_1 := ep1.API.V3 v3_2 := ep2.API.V3 return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token @@ -38,11 +41,11 @@ func Test_APIGetEndpoints(t *testing.T) { } testCases := []struct { - epl EndpointList + epl endpoints.List err error }{ { - epl: EndpointList{ + epl: endpoints.List{ API: "https://example.com/1", Authorization: "https://example.com/2", Token: "https://example.com/3", @@ -50,7 +53,7 @@ func Test_APIGetEndpoints(t *testing.T) { err: nil, }, { - epl: EndpointList{ + epl: endpoints.List{ API: "http://example.com/1", Authorization: "https://example.com/2", Token: "https://example.com/3", @@ -58,7 +61,7 @@ func Test_APIGetEndpoints(t *testing.T) { err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"), }, { - epl: EndpointList{ + epl: endpoints.List{ API: "https://example.com/1", Authorization: "https://example.com/2", Token: "ftp://example.com/3", @@ -66,7 +69,7 @@ func Test_APIGetEndpoints(t *testing.T) { err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"), }, { - epl: EndpointList{ + epl: endpoints.List{ API: "https://malicious.com/1", Authorization: "https://example.com/2", Token: "https://example.com/3", @@ -74,7 +77,7 @@ func Test_APIGetEndpoints(t *testing.T) { err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"), }, { - epl: EndpointList{ + epl: endpoints.List{ API: "https://example.com/1", Authorization: "https://example.com/2", Token: "https://malicious.com/3", @@ -82,7 +85,7 @@ func Test_APIGetEndpoints(t *testing.T) { err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"), }, { - epl: EndpointList{ + epl: endpoints.List{ API: "https://example.com/1", Authorization: "https://malicious.com/2", Token: "https://example.com/3", @@ -92,8 +95,8 @@ func Test_APIGetEndpoints(t *testing.T) { } for _, tc := range testCases { - ep := &Endpoints{ - API: EndpointsVersions{ + ep := &endpoints.Endpoints{ + API: endpoints.Versions{ V3: tc.epl, }, } @@ -108,7 +111,11 @@ func Test_APIGetEndpoints(t *testing.T) { fmt.Fprintln(w, string(jsonStr)) })) - gotEP, err := APIGetEndpoints(s.URL, c) + b := &base.Base{ + URL: s.URL, + HTTPClient: c, + } + err = Endpoints(context.Background(), b) if getErrorMsg(err) != getErrorMsg(tc.err) { t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err) } @@ -119,12 +126,9 @@ func Test_APIGetEndpoints(t *testing.T) { if ep == nil { t.Fatalf("No test case endpoints") } - if gotEP == nil { - t.Fatalf("Got no endpoints for nil error") - } // if no error then the endpoints should be equal - if !compareEndpoints(*ep, *gotEP) { - t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep) + if !compareEndpoints(*ep, b.Endpoints) { + t.Fatalf("Endpoints are not equal, got: %v, want: %v", b.Endpoints, ep) } } } diff --git a/internal/server/base.go b/internal/server/base/base.go index c7a9adc..d483dad 100644 --- a/internal/server/base.go +++ b/internal/server/base/base.go @@ -1,47 +1,25 @@ -package server +package base import ( "time" "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/server/profile" + "github.com/eduvpn/eduvpn-common/types/server" ) // Base is the base type for servers. type Base struct { - URL string `json:"base_url"` - DisplayName map[string]string `json:"display_name"` - SupportContact []string `json:"support_contact"` - Endpoints Endpoints `json:"endpoints"` - Profiles ProfileInfo `json:"profiles"` - StartTime time.Time `json:"start_time"` - EndTime time.Time `json:"expire_time"` - Type string `json:"server_type"` - httpClient *http.Client -} - -func (b *Base) InitializeEndpoints() error { - ep, err := APIGetEndpoints(b.URL, b.httpClient) - if err != nil { - return err - } - b.Endpoints = *ep - return nil -} - -func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo { - var valid []Profile - for _, p := range b.Profiles.Info.ProfileList { - // Not a valid profile because it does not support openvpn - // Also the client does not support wireguard - if !p.SupportsOpenVPN() && !wireguardSupport { - continue - } - valid = append(valid, p) - } - return ProfileInfo{ - Current: b.Profiles.Current, - Info: ProfileListInfo{ProfileList: valid}, - } + URL string `json:"base_url"` + DisplayName map[string]string `json:"display_name"` + SupportContact []string `json:"support_contact"` + Endpoints endpoints.Endpoints `json:"endpoints"` + Profiles profile.Info `json:"profiles"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"expire_time"` + Type server.Type `json:"server_type"` + HTTPClient *http.Client `json:"-"` } // RenewButtonTime returns the time when the renew button should be shown for the server diff --git a/internal/server/custom.go b/internal/server/custom.go deleted file mode 100644 index 6171e24..0000000 --- a/internal/server/custom.go +++ /dev/null @@ -1,35 +0,0 @@ -package server - -import ( - "github.com/go-errors/errors" -) - -func (ss *Servers) SetCustomServer(server Server) error { - b, err := server.Base() - if err != nil { - return err - } - - if b.Type != "custom_server" { - return errors.New("not a custom server") - } - - if _, ok := ss.CustomServers.Map[b.URL]; ok { - ss.CustomServers.CurrentURL = b.URL - ss.IsType = CustomServerType - } else { - return errors.Errorf("this server is not yet added as a custom server: %s", b.URL) - } - return nil -} - -func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { - if srv, ok := ss.CustomServers.Map[url]; ok { - return srv, nil - } - return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url) -} - -func (ss *Servers) RemoveCustomServer(url string) { - ss.CustomServers.Remove(url) -} diff --git a/internal/server/custom/custom.go b/internal/server/custom/custom.go new file mode 100644 index 0000000..14a72a5 --- /dev/null +++ b/internal/server/custom/custom.go @@ -0,0 +1,31 @@ +package custom + +import ( + "context" + + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/institute" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type ( + Server = institute.Server + Servers = institute.Servers +) + +func New(ctx context.Context, url string) (*Server, error) { + b := base.Base{ + URL: url, + DisplayName: map[string]string{"en": url}, + Type: server.TypeCustom, + } + if err := api.Endpoints(ctx, &b); err != nil { + return nil, err + } + API := b.Endpoints.API.V3 + + s := &Server{Basic: b} + s.Auth.Init(url, API.Authorization, API.Token) + return s, nil +} diff --git a/internal/server/endpoints/endpoints.go b/internal/server/endpoints/endpoints.go new file mode 100644 index 0000000..75bca55 --- /dev/null +++ b/internal/server/endpoints/endpoints.go @@ -0,0 +1,53 @@ +package endpoints + +import ( + "net/url" + + "github.com/go-errors/errors" +) + +type List struct { + API string `json:"api_endpoint"` + Authorization string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` +} + +type Versions struct { + V2 List `json:"http://eduvpn.org/api#2"` + V3 List `json:"http://eduvpn.org/api#3"` +} + +// Endpoints defines the json format for /.well-known/vpn-user-portal". +type Endpoints struct { + API Versions `json:"api"` + V string `json:"v"` +} + +func (e Endpoints) Validate() error { + v3 := e.API.V3 + pAPI, err := url.Parse(v3.API) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API endpoint", 0) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API token endpoint", 0) + } + if pAPI.Scheme != pAuth.Scheme { + return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + if pAPI.Host != pAuth.Host { + return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host) + } + if pAPI.Host != pToken.Host { + return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host) + } + return nil +} diff --git a/internal/server/institute/institute.go b/internal/server/institute/institute.go new file mode 100644 index 0000000..ada1977 --- /dev/null +++ b/internal/server/institute/institute.go @@ -0,0 +1,106 @@ +package institute + +import ( + "context" + + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/go-errors/errors" +) + +type Server struct { + // An instute access server has its own OAuth + Auth oauth.OAuth `json:"oauth"` + + // Embed the server base + Basic base.Base `json:"base"` +} + +type Servers struct { + Map map[string]*Server `json:"map"` + CurrentURL string `json:"current_url"` +} + +func New( + ctx context.Context, + url string, + name map[string]string, + supportContact []string, +) (*Server, error) { + b := base.Base{ + URL: url, + DisplayName: name, + SupportContact: supportContact, + Type: server.TypeInstituteAccess, + } + if err := api.Endpoints(ctx, &b); err != nil { + return nil, err + } + API := b.Endpoints.API.V3 + + s := &Server{Basic: b} + s.Auth.Init(url, API.Authorization, API.Token) + return s, nil +} + +func (s *Servers) Current() (*Server, error) { + if s.Map == nil { + return nil, errors.Errorf("No map is found when getting the current server") + } + + srv, ok := s.Map[s.CurrentURL] + if !ok || srv == nil { + return nil, errors.Errorf("server not found") + } + return srv, nil +} + +func (s *Servers) Remove(url string) error { + // check if it is in the map to begin with + if _, ok := s.Map[url]; ok { + delete(s.Map, url) + } else { + return errors.Errorf("cannot remove URL: %v, not found in list", url) + } + + // Reset the current url + if s.CurrentURL == url { + s.CurrentURL = "" + } + return nil +} + +func (s *Servers) Add(srv *Server) { + if s.Map == nil { + s.Map = make(map[string]*Server) + } + s.Map[srv.Basic.URL] = srv +} + +func (s *Server) TemplateAuth() func(string) string { + return func(authURL string) string { + return authURL + } +} + +func (s *Server) Base() (*base.Base, error) { + return &s.Basic, nil +} + +func (s *Server) OAuth() *oauth.OAuth { + return &s.Auth +} + +func (s *Server) NeedsLocation() bool { + return false +} + +func (s *Server) Public() (interface{}, error) { + return &server.Server{ + DisplayName: s.Basic.DisplayName, + Identifier: s.Basic.URL, + Profiles: s.Basic.Profiles.Public(), + }, nil +} diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go deleted file mode 100644 index ebafb26..0000000 --- a/internal/server/instituteaccess.go +++ /dev/null @@ -1,114 +0,0 @@ -package server - -import ( - "github.com/eduvpn/eduvpn-common/internal/discovery" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/go-errors/errors" -) - -type InstituteAccessServer struct { - // An instute access server has its own OAuth - Auth oauth.OAuth `json:"oauth"` - - // Embed the server base - Basic Base `json:"base"` -} - -type InstituteAccessServers struct { - Map map[string]*InstituteAccessServer `json:"map"` - CurrentURL string `json:"current_url"` -} - -func (ss *Servers) SetInstituteAccess(srv Server) error { - b, err := srv.Base() - if err != nil { - return err - } - - if b.Type != "institute_access" { - return errors.Errorf("not an institute access server, URL: %s, type: %s", b.URL, b.Type) - } - - if _, ok := ss.InstituteServers.Map[b.URL]; ok { - ss.InstituteServers.CurrentURL = b.URL - ss.IsType = InstituteAccessServerType - } else { - return errors.Errorf("institute access server with URL: %s, is not yet configured", b.URL) - } - return nil -} - -func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { - if srv, ok := ss.InstituteServers.Map[url]; ok { - return srv, nil - } - return nil, errors.Errorf("no institute access server with URL: %s", url) -} - -func (ss *Servers) RemoveInstituteAccess(url string) { - ss.InstituteServers.Remove(url) -} - -func (iass *InstituteAccessServers) Remove(url string) { - // Reset the current url - if iass.CurrentURL == url { - iass.CurrentURL = "" - } - - // Delete the url from the map - delete(iass.Map, url) -} - -func (ias *InstituteAccessServer) TemplateAuth() func(string) string { - return func(authURL string) string { - return authURL - } -} - -func (ias *InstituteAccessServer) Base() (*Base, error) { - return &ias.Basic, nil -} - -func (ias *InstituteAccessServer) OAuth() *oauth.OAuth { - return &ias.Auth -} - -func (ias *InstituteAccessServer) RefreshEndpoints(_ *discovery.Discovery) error { - // Re-initialize the endpoints - b, err := ias.Base() - if err != nil { - return err - } - - err = b.InitializeEndpoints() - if err != nil { - return err - } - - // update OAuth - auth := ias.OAuth() - if auth != nil { - auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization - auth.TokenURL = b.Endpoints.API.V3.Token - } - return nil -} - -func (ias *InstituteAccessServer) init( - url string, - name map[string]string, - srvType string, - supportContact []string, -) error { - ias.Basic.URL = url - ias.Basic.DisplayName = name - ias.Basic.SupportContact = supportContact - ias.Basic.Type = srvType - err := ias.Basic.InitializeEndpoints() - if err != nil { - return err - } - API := ias.Basic.Endpoints.API.V3 - ias.Auth.Init(url, API.Authorization, API.Token) - return nil -} diff --git a/internal/server/list.go b/internal/server/list.go new file mode 100644 index 0000000..2660102 --- /dev/null +++ b/internal/server/list.go @@ -0,0 +1,179 @@ +package server + +import ( + "context" + + "github.com/eduvpn/eduvpn-common/internal/server/custom" + "github.com/eduvpn/eduvpn-common/internal/server/institute" + "github.com/eduvpn/eduvpn-common/internal/server/secure" + discotypes "github.com/eduvpn/eduvpn-common/types/discovery" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" + "github.com/go-errors/errors" +) + +type List struct { + CustomServers custom.Servers `json:"custom_servers"` + InstituteServers institute.Servers `json:"institute_servers"` + SecureInternetHomeServer secure.Server `json:"secure_internet_home"` + IsType srvtypes.Type `json:"is_secure_internet"` +} + +// HasSecureInternet returns whether or not we have a secure internet server added +func (l *List) HasSecureInternet() bool { + return len(l.SecureInternetHomeServer.BaseMap) > 0 +} + +func (l *List) HasSecureLocation() bool { + return l.SecureInternetHomeServer.CurrentLocation != "" +} + +func (l *List) Current() (Server, error) { + if l.IsType == srvtypes.TypeUnknown { + return nil, errors.New("no current server") + } + if l.IsType == srvtypes.TypeSecureInternet { + if !l.HasSecureLocation() { + return nil, errors.Errorf("Current server is secure internet but there is no secure internet location: %v", l.IsType) + } + return &l.SecureInternetHomeServer, nil + } + + if l.IsType == srvtypes.TypeCustom { + return l.CustomServers.Current() + } + return l.InstituteServers.Current() +} + +func (l *List) AddCustom(ctx context.Context, url string) (Server, error) { + srv, err := custom.New(ctx, url) + if err != nil { + return nil, err + } + l.CustomServers.Add(srv) + return srv, nil +} + +func (l *List) AddInstituteAccess(ctx context.Context, discoServer *discotypes.Server) (Server, error) { + srv, err := institute.New(ctx, discoServer.BaseURL, discoServer.DisplayName, discoServer.SupportContact) + if err != nil { + return nil, err + } + l.InstituteServers.Add(srv) + return srv, nil +} + +func (l *List) AddSecureInternet( + ctx context.Context, + secureOrg *discotypes.Organization, + secureServer *discotypes.Server, +) (*secure.Server, error) { + // If we have specified an organization ID + // We also need to get an authorization template + err := l.SecureInternetHomeServer.Init(ctx, secureOrg, secureServer) + if err != nil { + return nil, err + } + + l.IsType = srvtypes.TypeSecureInternet + return &l.SecureInternetHomeServer, nil +} + +func (l *List) SecureInternet(identifier string) (*secure.Server, error) { + if l.SecureInternetHomeServer.HomeOrganizationID != identifier { + return nil, errors.Errorf("no secure internet home server with identifier: %s", identifier) + } + return &l.SecureInternetHomeServer, nil +} + +func (l *List) SetSecureInternet(server Server) error { + b, err := server.Base() + if err != nil { + return err + } + + if b.Type != srvtypes.TypeSecureInternet { + return errors.New("not a secure internet server") + } + + // The location should already be configured + // TODO: check for location? + l.IsType = srvtypes.TypeSecureInternet + return nil +} + +func (l *List) RemoveSecureInternet(identifier string) error { + oid := l.SecureInternetHomeServer.HomeOrganizationID + if identifier != oid { + return errors.Errorf("cannot remove secure internet server: identifier: %s, is not equal to the Org ID: %s", identifier, oid) + } + // Empty out the struct + l.SecureInternetHomeServer = secure.Server{} + + // If the current server is secure internet, reset to unknown + if l.IsType == srvtypes.TypeSecureInternet { + l.IsType = srvtypes.TypeUnknown + } + return nil +} + +func (l *List) SetInstituteAccess(srv Server) error { + b, err := srv.Base() + if err != nil { + return err + } + + if b.Type != srvtypes.TypeInstituteAccess { + return errors.Errorf("not an institute access server, URL: %s, type: %v", b.URL, b.Type) + } + + if _, ok := l.InstituteServers.Map[b.URL]; ok { + l.InstituteServers.CurrentURL = b.URL + l.IsType = srvtypes.TypeInstituteAccess + } else { + return errors.Errorf("institute access server with URL: %s, is not yet configured", b.URL) + } + return nil +} + +func (l *List) InstituteAccess(url string) (*institute.Server, error) { + if srv, ok := l.InstituteServers.Map[url]; ok { + return srv, nil + } + return nil, errors.Errorf("no institute access server with URL: %s", url) +} + +func (l *List) RemoveInstituteAccess(url string) error { + // TODO: Reset current to unknown? + return l.InstituteServers.Remove(url) +} + +func (l *List) SetCustom(server Server) error { + b, err := server.Base() + if err != nil { + return err + } + + if b.Type != srvtypes.TypeCustom { + return errors.New("not a custom server") + } + + if _, ok := l.CustomServers.Map[b.URL]; ok { + l.CustomServers.CurrentURL = b.URL + l.IsType = srvtypes.TypeCustom + } else { + return errors.Errorf("this server is not yet added as a custom server: %s", b.URL) + } + return nil +} + +func (l *List) CustomServer(url string) (*institute.Server, error) { + if srv, ok := l.CustomServers.Map[url]; ok { + return srv, nil + } + return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url) +} + +func (l *List) RemoveCustom(url string) error { + // TODO: Reset current to unknown? + return l.CustomServers.Remove(url) +} diff --git a/internal/server/profile.go b/internal/server/profile.go deleted file mode 100644 index d981421..0000000 --- a/internal/server/profile.go +++ /dev/null @@ -1,44 +0,0 @@ -package server - -type Profile struct { - ID string `json:"profile_id"` - DisplayName string `json:"display_name"` - VPNProtoList []string `json:"vpn_proto_list"` - DefaultGateway bool `json:"default_gateway"` -} - -type ProfileListInfo struct { - ProfileList []Profile `json:"profile_list"` -} - -type ProfileInfo struct { - Current string `json:"current_profile"` - Info ProfileListInfo `json:"info"` -} - -func (info ProfileInfo) CurrentProfileIndex() int { - for i, profile := range info.Info.ProfileList { - if profile.ID == info.Current { - return i - } - } - // Default is 'first' profile - return 0 -} - -func (profile *Profile) supportsProtocol(protocol string) bool { - for _, proto := range profile.VPNProtoList { - if proto == protocol { - return true - } - } - return false -} - -func (profile *Profile) SupportsWireguard() bool { - return profile.supportsProtocol("wireguard") -} - -func (profile *Profile) SupportsOpenVPN() bool { - return profile.supportsProtocol("openvpn") -} diff --git a/internal/server/profile/profile.go b/internal/server/profile/profile.go new file mode 100644 index 0000000..7a19685 --- /dev/null +++ b/internal/server/profile/profile.go @@ -0,0 +1,88 @@ +package profile + +import ( + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Profile struct { + ID string `json:"profile_id"` + DisplayName string `json:"display_name"` + VPNProtoList []string `json:"vpn_proto_list"` + DefaultGateway bool `json:"default_gateway"` +} + +type ListInfo struct { + ProfileList []Profile `json:"profile_list"` +} + +type Info struct { + Current string `json:"current_profile"` + Info ListInfo `json:"info"` +} + +func (info Info) CurrentProfileIndex() int { + for i, profile := range info.Info.ProfileList { + if profile.ID == info.Current { + return i + } + } + // Default is 'first' profile + return 0 +} + +func (profile *Profile) supportsProtocol(protocol string) bool { + for _, proto := range profile.VPNProtoList { + if proto == protocol { + return true + } + } + return false +} + +func (profile *Profile) SupportsWireguard() bool { + return profile.supportsProtocol("wireguard") +} + +func (profile *Profile) SupportsOpenVPN() bool { + return profile.supportsProtocol("openvpn") +} + +func (info Info) Supported(wireguardSupport bool) []Profile { + var valid []Profile + for _, p := range info.Info.ProfileList { + // Not a valid profile because it does not support openvpn + // Also the client does not support wireguard + if !p.SupportsOpenVPN() && !wireguardSupport { + continue + } + valid = append(valid, p) + } + return valid +} + +func (info Info) Has(id string) bool { + for _, p := range info.Info.ProfileList { + if p.ID == id { + return true + } + } + return false +} + +func (info Info) Public() server.Profiles { + m := make(map[string]server.Profile) + for _, p := range info.Info.ProfileList { + var protocols []protocol.Protocol + for _, ps := range p.VPNProtoList { + protocols = append(protocols, protocol.New(ps)) + } + m[p.ID] = server.Profile{ + DisplayName: map[string]string{ + "en": p.DisplayName, + }, + Protocols: protocols, + } + } + return server.Profiles{Map: m, Current: info.Current} +} diff --git a/internal/server/profile_test.go b/internal/server/profile/profile_test.go index d6a7e9d..e246b5c 100644 --- a/internal/server/profile_test.go +++ b/internal/server/profile/profile_test.go @@ -1,4 +1,4 @@ -package server +package profile import "testing" @@ -86,9 +86,9 @@ func Test_CurrentProfileIndex(t *testing.T) { } for _, tc := range testCases { - pri := &ProfileInfo{ + pri := &Info{ Current: tc.current, - Info: ProfileListInfo{ + Info: ListInfo{ ProfileList: tc.profiles, }, } diff --git a/internal/server/secure/secure.go b/internal/server/secure/secure.go new file mode 100644 index 0000000..6fed010 --- /dev/null +++ b/internal/server/secure/secure.go @@ -0,0 +1,148 @@ +package secure + +import ( + "context" + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/util" + discotypes "github.com/eduvpn/eduvpn-common/types/discovery" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/go-errors/errors" +) + +// Server secure internet server which has its own OAuth tokens +// It specifies the current location url it is connected to. +type Server struct { + Auth oauth.OAuth `json:"oauth"` + DisplayName map[string]string `json:"display_name"` + + // The home server has a list of info for each configured server location + BaseMap map[string]*base.Base `json:"base_map"` + + // We have the authorization URL template, the home organization ID and the current location + AuthorizationTemplate string `json:"authorization_template"` + HomeOrganizationID string `json:"home_organization_id"` + CurrentLocation string `json:"current_location"` +} + +func (s *Server) TemplateAuth() func(string) string { + return func(authURL string) string { + return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID) + } +} + +func (s *Server) Base() (*base.Base, error) { + if s.BaseMap == nil { + return nil, errors.Errorf("secure internet map not found") + } + + b, ok := s.BaseMap[s.CurrentLocation] + if !ok { + return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation) + } + return b, nil +} + +func (s *Server) OAuth() *oauth.OAuth { + return &s.Auth +} + +func (s *Server) NeedsLocation() bool { + if s.CurrentLocation == "" { + return true + } + if len(s.BaseMap) == 0 { + return true + } + return false +} + +func (s *Server) addLocation(ctx context.Context, locSrv *discotypes.Server) (*base.Base, error) { + // Initialize the base map if it is non-nil + if s.BaseMap == nil { + s.BaseMap = make(map[string]*base.Base) + } + + // Add the location to the base map + b, ok := s.BaseMap[locSrv.CountryCode] + if !ok || b == nil { + // Create the base to be added to the map + b = &base.Base{} + b.URL = locSrv.BaseURL + b.DisplayName = s.DisplayName + b.SupportContact = locSrv.SupportContact + b.Type = server.TypeSecureInternet + if err := api.Endpoints(ctx, b); err != nil { + return nil, err + } + } + + // Ensure it is in the map + s.BaseMap[locSrv.CountryCode] = b + return b, nil +} + +func (s *Server) Location(ctx context.Context, locSrv *discotypes.Server) error { + if _, err := s.addLocation(ctx, locSrv); err != nil { + return err + } + s.CurrentLocation = locSrv.CountryCode + return nil +} + +// Initializes the home server and adds its own location. +func (s *Server) Init( + ctx context.Context, + homeOrg *discotypes.Organization, homeLoc *discotypes.Server, +) error { + if s.HomeOrganizationID != homeOrg.OrgID { + // New home organisation, clear everything + *s = Server{} + } + + // Make sure to set the organization ID + s.HomeOrganizationID = homeOrg.OrgID + s.DisplayName = homeOrg.DisplayName + + // Make sure to set the authorization URL template + s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate + + b, err := s.addLocation(ctx, homeLoc) + if err != nil { + return err + } + + // set the home location as the current + err = s.Location(ctx, homeLoc) + if err != nil { + return err + } + + // Set the current location to the home location if there is none + if s.CurrentLocation == "" { + s.CurrentLocation = homeLoc.CountryCode + } + + // Make sure oauth contains our endpoints + s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) + return nil +} + +func (s *Server) Public() (interface{}, error) { + b, err := s.Base() + var p server.Profiles + dn := s.DisplayName + if err == nil { + dn = b.DisplayName + p = b.Profiles.Public() + } + return &server.SecureInternet{ + Server: server.Server{ + DisplayName: dn, + Identifier: s.HomeOrganizationID, + Profiles: p, + }, + CountryCode: s.CurrentLocation, + }, nil +} diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go deleted file mode 100644 index 4b42303..0000000 --- a/internal/server/secureinternet.go +++ /dev/null @@ -1,175 +0,0 @@ -package server - -import ( - "github.com/eduvpn/eduvpn-common/internal/discovery" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/util" - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/go-errors/errors" -) - -// SecureInternetHomeServer secure internet server which has its own OAuth tokens -// It specifies the current location url it is connected to. -type SecureInternetHomeServer struct { - Auth oauth.OAuth `json:"oauth"` - DisplayName map[string]string `json:"display_name"` - - // The home server has a list of info for each configured server location - BaseMap map[string]*Base `json:"base_map"` - - // We have the authorization URL template, the home organization ID and the current location - AuthorizationTemplate string `json:"authorization_template"` - HomeOrganizationID string `json:"home_organization_id"` - CurrentLocation string `json:"current_location"` -} - -func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { - if !ss.HasSecureLocation() { - return nil, errors.Errorf("no secure internet home server") - } - return &ss.SecureInternetHomeServer, nil -} - -func (ss *Servers) SetSecureInternet(server Server) error { - b, err := server.Base() - if err != nil { - return err - } - - if b.Type != "secure_internet" { - return errors.Errorf("not a secure internet server") - } - - // The location should already be configured - // TODO: check for location? - ss.IsType = SecureInternetServerType - return nil -} - -func (ss *Servers) RemoveSecureInternet() { - // Empty out the struct - ss.SecureInternetHomeServer = SecureInternetHomeServer{} - - // If the current server is secure internet, default to custom server - if ss.IsType == SecureInternetServerType { - ss.IsType = CustomServerType - } -} - -func (s *SecureInternetHomeServer) TemplateAuth() func(string) string { - return func(authURL string) string { - return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID) - } -} - -func (s *SecureInternetHomeServer) Base() (*Base, error) { - if s.BaseMap == nil { - return nil, errors.Errorf("secure internet map not found") - } - - b, ok := s.BaseMap[s.CurrentLocation] - if !ok { - return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation) - } - return b, nil -} - -func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth { - return &s.Auth -} - -func (ss *Servers) HasSecureLocation() bool { - return ss.SecureInternetHomeServer.CurrentLocation != "" -} - -func (s *SecureInternetHomeServer) addLocation(locSrv *discotypes.Server) (*Base, error) { - // Initialize the base map if it is non-nil - if s.BaseMap == nil { - s.BaseMap = make(map[string]*Base) - } - - // Add the location to the base map - b, ok := s.BaseMap[locSrv.CountryCode] - if !ok || b == nil { - // Create the base to be added to the map - b = &Base{} - b.URL = locSrv.BaseURL - b.DisplayName = s.DisplayName - b.SupportContact = locSrv.SupportContact - b.Type = "secure_internet" - if err := b.InitializeEndpoints(); err != nil { - return nil, err - } - } - - // Ensure it is in the map - s.BaseMap[locSrv.CountryCode] = b - return b, nil -} - -// Initializes the home server and adds its own location. -func (s *SecureInternetHomeServer) init( - homeOrg *discotypes.Organization, homeLoc *discotypes.Server, -) error { - if s.HomeOrganizationID != homeOrg.OrgID { - // New home organisation, clear everything - *s = SecureInternetHomeServer{} - } - - // Make sure to set the organization ID - s.HomeOrganizationID = homeOrg.OrgID - s.DisplayName = homeOrg.DisplayName - - // Make sure to set the authorization URL template - s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate - - b, err := s.addLocation(homeLoc) - if err != nil { - return err - } - - // Set the current location to the home location if there is none - if s.CurrentLocation == "" { - s.CurrentLocation = homeLoc.CountryCode - } - - // Make sure oauth contains our endpoints - s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) - return nil -} - -func (s *SecureInternetHomeServer) RefreshEndpoints(disco *discovery.Discovery) error { - // update OAuth for home server - auth := s.OAuth() - if auth != nil && s.HomeOrganizationID != "" { - _, srv, err := disco.SecureHomeArgs(s.HomeOrganizationID) - if err != nil { - return err - } - if hb, ok := s.BaseMap[srv.CountryCode]; ok && hb != nil { - err := hb.InitializeEndpoints() - if err != nil { - return err - } - auth.BaseAuthorizationURL = hb.Endpoints.API.V3.Authorization - auth.TokenURL = hb.Endpoints.API.V3.Token - } - // already updated, return - if srv.CountryCode == s.CurrentLocation { - return nil - } - } - - // refresh the current location endpoints - // Re-initialize the endpoints - b, err := s.Base() - if err != nil { - return err - } - - err = b.InitializeEndpoints() - if err != nil { - return err - } - return nil -} diff --git a/internal/server/server.go b/internal/server/server.go index 4bd8766..e7229c5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,23 +1,21 @@ package server import ( + "context" "os" "time" "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/profile" "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type Type int8 - -const ( - CustomServerType Type = iota - InstituteAccessServerType - SecureInternetServerType -) - type Server interface { OAuth() *oauth.OAuth @@ -25,27 +23,13 @@ type Server interface { TemplateAuth() func(string) string // Base returns the server base - Base() (*Base, error) + Base() (*base.Base, error) - // RefreshEndpoints - RefreshEndpoints(*discovery.Discovery) error -} + // NeedsLocation checks if the server needs a secure internet location + NeedsLocation() bool -type EndpointList struct { - API string `json:"api_endpoint"` - Authorization string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` -} - -type EndpointsVersions struct { - V2 EndpointList `json:"http://eduvpn.org/api#2"` - V3 EndpointList `json:"http://eduvpn.org/api#3"` -} - -// Endpoints defines the json format for /.well-known/vpn-user-portal". -type Endpoints struct { - API EndpointsVersions `json:"api"` - V string `json:"v"` + // Public returns the representation that will be passed over the CGO barrier + Public() (interface{}, error) } func UpdateTokens(srv Server, t oauth.Token) { @@ -56,12 +40,12 @@ func OAuthURL(srv Server, name string) (string, error) { return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } -func OAuthExchange(srv Server) error { - return srv.OAuth().Exchange() +func OAuthExchange(ctx context.Context, srv Server) error { + return srv.OAuth().Exchange(ctx) } -func HeaderToken(srv Server) (string, error) { - return srv.OAuth().AccessToken() +func HeaderToken(ctx context.Context, srv Server) (string, error) { + return srv.OAuth().AccessToken(ctx) } func MarkTokenExpired(srv Server) { @@ -72,16 +56,13 @@ func MarkTokensForRenew(srv Server) { srv.OAuth().SetTokenRenew() } -func NeedsRelogin(srv Server) bool { - _, err := HeaderToken(srv) +func NeedsRelogin(ctx context.Context, srv Server) bool { + // TODO: this error can be a context cancel + _, err := HeaderToken(ctx, srv) return err != nil } -func CancelOAuth(srv Server) { - srv.OAuth().Cancel() -} - -func CurrentProfile(srv Server) (*Profile, error) { +func CurrentProfile(srv Server) (*profile.Profile, error) { b, err := srv.Base() if err != nil { return nil, err @@ -96,19 +77,31 @@ func CurrentProfile(srv Server) (*Profile, error) { return nil, errors.Errorf("profile not found: " + pID) } -func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { +func ValidProfiles(srv Server, wireguardSupport bool) (*[]profile.Profile, error) { // No error wrapping here otherwise we wrap it too much b, err := srv.Base() if err != nil { return nil, err } - ps := b.ValidProfiles(wireguardSupport) - if len(ps.Info.ProfileList) == 0 { + ps := b.Profiles.Supported(wireguardSupport) + if len(ps) == 0 { return nil, errors.Errorf("no profiles found with supported protocols") } return &ps, nil } +func Profile(srv Server, id string) error { + b, err := srv.Base() + if err != nil { + return err + } + if !b.Profiles.Has(id) { + return errors.Errorf("no profile available with id: %s", id) + } + b.Profiles.Current = id + return nil +} + type ConfigData struct { // The configuration Config string @@ -120,7 +113,18 @@ type ConfigData struct { Tokens oauth.Token } -func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { +// Public gets the public data from the types package +// dg specifies if this config is default gateway +func (c *ConfigData) Public(dg bool) srvtypes.Configuration { + return srvtypes.Configuration{ + VPNConfig: c.Config, + Protocol: protocol.New(c.Type), + DefaultGateway: dg, + Tokens: c.Tokens.Public(), + } +} + +func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { return nil, err @@ -133,7 +137,7 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi } pub := key.PublicKey().String() - cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport) + cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport) if err != nil { return nil, err } @@ -159,13 +163,13 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil } -func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { +func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { return nil, err } pid := b.Profiles.Current - cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) + cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP) if err != nil { return nil, err } @@ -184,15 +188,14 @@ func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil } -func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { - // Get new profiles using the info call - // This does not override the current profile - err := APIInfo(srv) +func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) { + b, err := srv.Base() if err != nil { return false, err } - - b, err := srv.Base() + // Get new profiles using the info call + // This does not override the current profile + err = api.Info(ctx, b, srv.OAuth()) if err != nil { return false, err } @@ -225,7 +228,18 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { return true, nil } -func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { +func RefreshEndpoints(ctx context.Context, srv Server) error { + // Re-initialize the endpoints + // TODO: Make this a warning instead? + b, err := srv.Base() + if err != nil { + return err + } + + return api.Endpoints(ctx, b) +} + +func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { return nil, err @@ -250,10 +264,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, case wg: // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - cfg, err = wireguardGetConfig(server, preferTCP, ovpn) + cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn) // The config only supports OpenVPN case ovpn: - cfg, err = openVPNGetConfig(server, preferTCP) + cfg, err = openVPNGetConfig(ctx, server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: return nil, errors.New("no supported protocol found") @@ -267,6 +281,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, return cfg, err } -func Disconnect(server Server) error { - return APIDisconnect(server) +func Disconnect(ctx context.Context, server Server) error { + b, err := server.Base() + if err != nil { + return err + } + return api.Disconnect(ctx, b, server.OAuth()) } diff --git a/internal/server/servers.go b/internal/server/servers.go deleted file mode 100644 index 60c993d..0000000 --- a/internal/server/servers.go +++ /dev/null @@ -1,121 +0,0 @@ -package server - -import ( - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/go-errors/errors" -) - -// TODO: Have a dedicated type for custom servers -type Servers struct { - // A custom server is just an institute access server under the hood - CustomServers InstituteAccessServers `json:"custom_servers"` - InstituteServers InstituteAccessServers `json:"institute_servers"` - SecureInternetHomeServer SecureInternetHomeServer `json:"secure_internet_home"` - IsType Type `json:"is_secure_internet"` -} - -// HasSecureInternet returns whether or not we have a secure internet server added -func (ss *Servers) HasSecureInternet() bool { - return len(ss.SecureInternetHomeServer.BaseMap) > 0 -} - -func (ss *Servers) AddSecureInternet( - secureOrg *discotypes.Organization, - secureServer *discotypes.Server, -) (Server, error) { - // If we have specified an organization ID - // We also need to get an authorization template - err := ss.SecureInternetHomeServer.init(secureOrg, secureServer) - if err != nil { - return nil, err - } - - ss.IsType = SecureInternetServerType - return &ss.SecureInternetHomeServer, nil -} - -func (ss *Servers) GetCurrentServer() (Server, error) { - // TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server) - if ss.IsType == SecureInternetServerType { - if !ss.HasSecureLocation() { - return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType) - } - return &ss.SecureInternetHomeServer, nil - } - - srvs := &ss.InstituteServers - - if ss.IsType == CustomServerType { - srvs = &ss.CustomServers - } - if srvs.Map == nil { - return nil, errors.Errorf("srvs.Map is nil") - } - - srv, ok := srvs.Map[srvs.CurrentURL] - if !ok || srv == nil { - return nil, errors.Errorf("server not found") - } - return srv, nil -} - -func (ss *Servers) addInstituteAndCustom( - discoServer *discotypes.Server, - isCustom bool, -) (Server, error) { - URL := discoServer.BaseURL - srvs := &ss.InstituteServers - srvType := InstituteAccessServerType - - if isCustom { - srvs = &ss.CustomServers - srvType = CustomServerType - } - - if srvs.Map == nil { - srvs.Map = make(map[string]*InstituteAccessServer) - } - - srv, ok := srvs.Map[URL] - - // initialize the server if it doesn't exist yet - if !ok { - srv = &InstituteAccessServer{} - } - - if err := srv.init(URL, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil { - return nil, err - } - srvs.Map[URL] = srv - ss.IsType = srvType - return srv, nil -} - -func (ss *Servers) AddInstituteAccessServer( - instituteServer *discotypes.Server, -) (Server, error) { - return ss.addInstituteAndCustom(instituteServer, false) -} - -func (ss *Servers) AddCustomServer( - customServer *discotypes.Server, -) (Server, error) { - return ss.addInstituteAndCustom(customServer, true) -} - -func (ss *Servers) GetSecureLocation() string { - return ss.SecureInternetHomeServer.CurrentLocation -} - -func (ss *Servers) SetSecureLocation( - chosenLocationServer *discotypes.Server, -) error { - // Make sure to add the current location - - if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil { - return err - } - - ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode - return nil -} |
