summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-04-12 22:52:49 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2023-09-25 09:43:37 +0200
commita23c3e61c5d89ef67973891b5b3a176c06e1b174 (patch)
treef1eed03b047f8affd3d5123fa5c9e868ac7d8bec /internal
parentee95eb45708e1fa766a63866d26d05d13f23e8c9 (diff)
Refactor: Split internal server into multiple packages
- Pass contexts - Have separate packages for e.g. custom, institute and secure - internet servers, profiles.... - Return types from the public ./types package with a Public() method
Diffstat (limited to 'internal')
-rw-r--r--internal/server/api/api.go (renamed from internal/server/api.go)120
-rw-r--r--internal/server/api/api_test.go (renamed from internal/server/api_test.go)38
-rw-r--r--internal/server/base/base.go (renamed from internal/server/base.go)48
-rw-r--r--internal/server/custom.go35
-rw-r--r--internal/server/custom/custom.go31
-rw-r--r--internal/server/endpoints/endpoints.go53
-rw-r--r--internal/server/institute/institute.go106
-rw-r--r--internal/server/instituteaccess.go114
-rw-r--r--internal/server/list.go179
-rw-r--r--internal/server/profile.go44
-rw-r--r--internal/server/profile/profile.go88
-rw-r--r--internal/server/profile/profile_test.go (renamed from internal/server/profile_test.go)6
-rw-r--r--internal/server/secure/secure.go148
-rw-r--r--internal/server/secureinternet.go175
-rw-r--r--internal/server/server.go130
-rw-r--r--internal/server/servers.go121
16 files changed, 763 insertions, 673 deletions
diff --git a/internal/server/api.go b/internal/server/api/api.go
index 546c02a..9ad6f2d 100644
--- a/internal/server/api.go
+++ b/internal/server/api/api.go
@@ -1,6 +1,7 @@
-package server
+package api
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
@@ -10,65 +11,43 @@ import (
httpw "github.com/eduvpn/eduvpn-common/internal/http"
"github.com/eduvpn/eduvpn-common/internal/log"
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
+ "github.com/eduvpn/eduvpn-common/internal/server/profile"
"github.com/go-errors/errors"
)
-func validateEndpoints(endpoints Endpoints) error {
- v3 := endpoints.API.V3
- pAPI, err := url.Parse(v3.API)
+func Endpoints(ctx context.Context, b *base.Base) error {
+ uStr, err := httpw.JoinURLPath(b.URL, "/.well-known/vpn-user-portal")
if err != nil {
- return errors.WrapPrefix(err, "failed to parse API endpoint", 0)
- }
- pAuth, err := url.Parse(v3.Authorization)
- if err != nil {
- return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0)
- }
- pToken, err := url.Parse(v3.Token)
- if err != nil {
- return errors.WrapPrefix(err, "failed to parse API token endpoint", 0)
- }
- if pAPI.Scheme != pAuth.Scheme {
- return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
- }
- if pAPI.Scheme != pToken.Scheme {
- return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
- }
- if pAPI.Host != pAuth.Host {
- return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host)
- }
- if pAPI.Host != pToken.Host {
- return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host)
- }
- return nil
-}
-
-func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) {
- uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal")
- if err != nil {
- return nil, err
+ return err
}
- if client == nil {
- client = httpw.NewClient()
+ if b.HTTPClient == nil {
+ b.HTTPClient = httpw.NewClient()
}
- _, body, err := client.Get(uStr)
+ _, body, err := b.HTTPClient.Get(ctx, uStr)
if err != nil {
- return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ return errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- ep := Endpoints{}
+ ep := endpoints.Endpoints{}
if err = json.Unmarshal(body, &ep); err != nil {
- return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
+ return errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- err = validateEndpoints(ep)
+ err = ep.Validate()
if err != nil {
- return nil, err
+ return err
}
- return &ep, nil
+ b.Endpoints = ep
+ return nil
}
-func apiAuthorized(
- srv Server,
+func authorized(
+ ctx context.Context,
+ b *base.Base,
+ oauth *oauth.OAuth,
method string,
endpoint string,
opts *httpw.OptionalParams,
@@ -78,10 +57,6 @@ func apiAuthorized(
opts = &httpw.OptionalParams{}
}
errorMessage := "failed API authorized"
- b, err := srv.Base()
- if err != nil {
- return nil, nil, errors.WrapPrefix(err, errorMessage, 0)
- }
// Join the paths
u, err := url.Parse(b.Endpoints.API.V3.API)
@@ -91,7 +66,7 @@ func apiAuthorized(
u.Path = path.Join(u.Path, endpoint)
// Make sure the tokens are valid, this will return an error if re-login is needed
- t, err := HeaderToken(srv)
+ t, err := oauth.AccessToken(ctx)
if err != nil {
return nil, nil, errors.WrapPrefix(err, errorMessage, 0)
}
@@ -105,19 +80,21 @@ func apiAuthorized(
}
// Create a client if it doesn't exist
- if b.httpClient == nil {
- b.httpClient = httpw.NewClient()
+ if b.HTTPClient == nil {
+ b.HTTPClient = httpw.NewClient()
}
- return b.httpClient.Do(method, u.String(), opts)
+ return b.HTTPClient.Do(ctx, method, u.String(), opts)
}
-func apiAuthorizedRetry(
- srv Server,
+func authorizedRetry(
+ ctx context.Context,
+ b *base.Base,
+ auth *oauth.OAuth,
method string,
endpoint string,
opts *httpw.OptionalParams,
) (http.Header, []byte, error) {
- h, body, err := apiAuthorized(srv, method, endpoint, opts)
+ h, body, err := authorized(ctx, b, auth, method, endpoint, opts)
if err == nil {
return h, body, nil
}
@@ -127,27 +104,22 @@ func apiAuthorizedRetry(
if errors.As(err, &statErr) && statErr.Status == 401 {
log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint)
// Mark the token as expired and retry, so we trigger the refresh flow
- MarkTokenExpired(srv)
- h, body, err = apiAuthorized(srv, method, endpoint, opts)
+ auth.SetTokenExpired()
+ h, body, err = authorized(ctx, b, auth, method, endpoint, opts)
}
return h, body, err
}
-func APIInfo(srv Server) error {
- _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil)
+func Info(ctx context.Context, b *base.Base, auth *oauth.OAuth) error {
+ _, body, err := authorizedRetry(ctx, b, auth, http.MethodGet, "/info", nil)
if err != nil {
return err
}
- profiles := ProfileInfo{}
+ profiles := profile.Info{}
if err = json.Unmarshal(body, &profiles); err != nil {
return errors.WrapPrefix(err, "failed API /info", 0)
}
- b, err := srv.Base()
- if err != nil {
- return err
- }
-
// Store the profiles and make sure that the current profile is not overwritten
prev := b.Profiles.Current
b.Profiles = profiles
@@ -163,8 +135,10 @@ func boolToYesNo(preferTCP bool) string {
return "no"
}
-func APIConnectWireguard(
- srv Server,
+func ConnectWireguard(
+ ctx context.Context,
+ b *base.Base,
+ auth *oauth.OAuth,
profileID string,
pubkey string,
preferTCP bool,
@@ -186,7 +160,7 @@ func APIConnectWireguard(
"public_key": {pubkey},
"prefer_tcp": {boolToYesNo(preferTCP)},
}
- h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect",
+ h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect",
&httpw.OptionalParams{Headers: hdrs, Body: vals})
if err != nil {
return "", "", time.Time{}, err
@@ -208,7 +182,7 @@ func APIConnectWireguard(
return string(body), content, expTime, nil
}
-func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) {
+func ConnectOpenVPN(ctx context.Context, b *base.Base, auth *oauth.OAuth, profileID string, preferTCP bool) (string, time.Time, error) {
hdrs := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-openvpn-profile"},
@@ -219,7 +193,7 @@ func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, ti
"prefer_tcp": {boolToYesNo(preferTCP)},
}
- h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect",
+ h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect",
&httpw.OptionalParams{Headers: hdrs, Body: vals})
if err != nil {
return "", time.Time{}, err
@@ -234,10 +208,10 @@ func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, ti
return string(body), expT, nil
}
-// APIDisconnect disconnects from the API.
-func APIDisconnect(server Server) error {
+// Disconnect disconnects the VPN using the API.
+func Disconnect(ctx context.Context, b *base.Base, auth *oauth.OAuth) error {
// The timeout is a bit lower here such that this does not take a too long time for disconnecting
// Clients may wish to retry this
- _, _, err := apiAuthorized(server, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second})
+ _, _, err := authorized(ctx, b, auth, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second})
return err
}
diff --git a/internal/server/api_test.go b/internal/server/api/api_test.go
index b1e3550..7509a30 100644
--- a/internal/server/api_test.go
+++ b/internal/server/api/api_test.go
@@ -1,11 +1,14 @@
-package server
+package api
import (
+ "context"
"encoding/json"
"fmt"
"net/http"
"testing"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
"github.com/eduvpn/eduvpn-common/internal/test"
"github.com/go-errors/errors"
)
@@ -17,7 +20,7 @@ func getErrorMsg(err error) string {
return err.Error()
}
-func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool {
+func compareEndpoints(ep1 endpoints.Endpoints, ep2 endpoints.Endpoints) bool {
v3_1 := ep1.API.V3
v3_2 := ep2.API.V3
return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token
@@ -38,11 +41,11 @@ func Test_APIGetEndpoints(t *testing.T) {
}
testCases := []struct {
- epl EndpointList
+ epl endpoints.List
err error
}{
{
- epl: EndpointList{
+ epl: endpoints.List{
API: "https://example.com/1",
Authorization: "https://example.com/2",
Token: "https://example.com/3",
@@ -50,7 +53,7 @@ func Test_APIGetEndpoints(t *testing.T) {
err: nil,
},
{
- epl: EndpointList{
+ epl: endpoints.List{
API: "http://example.com/1",
Authorization: "https://example.com/2",
Token: "https://example.com/3",
@@ -58,7 +61,7 @@ func Test_APIGetEndpoints(t *testing.T) {
err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"),
},
{
- epl: EndpointList{
+ epl: endpoints.List{
API: "https://example.com/1",
Authorization: "https://example.com/2",
Token: "ftp://example.com/3",
@@ -66,7 +69,7 @@ func Test_APIGetEndpoints(t *testing.T) {
err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"),
},
{
- epl: EndpointList{
+ epl: endpoints.List{
API: "https://malicious.com/1",
Authorization: "https://example.com/2",
Token: "https://example.com/3",
@@ -74,7 +77,7 @@ func Test_APIGetEndpoints(t *testing.T) {
err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"),
},
{
- epl: EndpointList{
+ epl: endpoints.List{
API: "https://example.com/1",
Authorization: "https://example.com/2",
Token: "https://malicious.com/3",
@@ -82,7 +85,7 @@ func Test_APIGetEndpoints(t *testing.T) {
err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"),
},
{
- epl: EndpointList{
+ epl: endpoints.List{
API: "https://example.com/1",
Authorization: "https://malicious.com/2",
Token: "https://example.com/3",
@@ -92,8 +95,8 @@ func Test_APIGetEndpoints(t *testing.T) {
}
for _, tc := range testCases {
- ep := &Endpoints{
- API: EndpointsVersions{
+ ep := &endpoints.Endpoints{
+ API: endpoints.Versions{
V3: tc.epl,
},
}
@@ -108,7 +111,11 @@ func Test_APIGetEndpoints(t *testing.T) {
fmt.Fprintln(w, string(jsonStr))
}))
- gotEP, err := APIGetEndpoints(s.URL, c)
+ b := &base.Base{
+ URL: s.URL,
+ HTTPClient: c,
+ }
+ err = Endpoints(context.Background(), b)
if getErrorMsg(err) != getErrorMsg(tc.err) {
t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err)
}
@@ -119,12 +126,9 @@ func Test_APIGetEndpoints(t *testing.T) {
if ep == nil {
t.Fatalf("No test case endpoints")
}
- if gotEP == nil {
- t.Fatalf("Got no endpoints for nil error")
- }
// if no error then the endpoints should be equal
- if !compareEndpoints(*ep, *gotEP) {
- t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep)
+ if !compareEndpoints(*ep, b.Endpoints) {
+ t.Fatalf("Endpoints are not equal, got: %v, want: %v", b.Endpoints, ep)
}
}
}
diff --git a/internal/server/base.go b/internal/server/base/base.go
index c7a9adc..d483dad 100644
--- a/internal/server/base.go
+++ b/internal/server/base/base.go
@@ -1,47 +1,25 @@
-package server
+package base
import (
"time"
"github.com/eduvpn/eduvpn-common/internal/http"
+ "github.com/eduvpn/eduvpn-common/internal/server/endpoints"
+ "github.com/eduvpn/eduvpn-common/internal/server/profile"
+ "github.com/eduvpn/eduvpn-common/types/server"
)
// Base is the base type for servers.
type Base struct {
- URL string `json:"base_url"`
- DisplayName map[string]string `json:"display_name"`
- SupportContact []string `json:"support_contact"`
- Endpoints Endpoints `json:"endpoints"`
- Profiles ProfileInfo `json:"profiles"`
- StartTime time.Time `json:"start_time"`
- EndTime time.Time `json:"expire_time"`
- Type string `json:"server_type"`
- httpClient *http.Client
-}
-
-func (b *Base) InitializeEndpoints() error {
- ep, err := APIGetEndpoints(b.URL, b.httpClient)
- if err != nil {
- return err
- }
- b.Endpoints = *ep
- return nil
-}
-
-func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo {
- var valid []Profile
- for _, p := range b.Profiles.Info.ProfileList {
- // Not a valid profile because it does not support openvpn
- // Also the client does not support wireguard
- if !p.SupportsOpenVPN() && !wireguardSupport {
- continue
- }
- valid = append(valid, p)
- }
- return ProfileInfo{
- Current: b.Profiles.Current,
- Info: ProfileListInfo{ProfileList: valid},
- }
+ URL string `json:"base_url"`
+ DisplayName map[string]string `json:"display_name"`
+ SupportContact []string `json:"support_contact"`
+ Endpoints endpoints.Endpoints `json:"endpoints"`
+ Profiles profile.Info `json:"profiles"`
+ StartTime time.Time `json:"start_time"`
+ EndTime time.Time `json:"expire_time"`
+ Type server.Type `json:"server_type"`
+ HTTPClient *http.Client `json:"-"`
}
// RenewButtonTime returns the time when the renew button should be shown for the server
diff --git a/internal/server/custom.go b/internal/server/custom.go
deleted file mode 100644
index 6171e24..0000000
--- a/internal/server/custom.go
+++ /dev/null
@@ -1,35 +0,0 @@
-package server
-
-import (
- "github.com/go-errors/errors"
-)
-
-func (ss *Servers) SetCustomServer(server Server) error {
- b, err := server.Base()
- if err != nil {
- return err
- }
-
- if b.Type != "custom_server" {
- return errors.New("not a custom server")
- }
-
- if _, ok := ss.CustomServers.Map[b.URL]; ok {
- ss.CustomServers.CurrentURL = b.URL
- ss.IsType = CustomServerType
- } else {
- return errors.Errorf("this server is not yet added as a custom server: %s", b.URL)
- }
- return nil
-}
-
-func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) {
- if srv, ok := ss.CustomServers.Map[url]; ok {
- return srv, nil
- }
- return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url)
-}
-
-func (ss *Servers) RemoveCustomServer(url string) {
- ss.CustomServers.Remove(url)
-}
diff --git a/internal/server/custom/custom.go b/internal/server/custom/custom.go
new file mode 100644
index 0000000..14a72a5
--- /dev/null
+++ b/internal/server/custom/custom.go
@@ -0,0 +1,31 @@
+package custom
+
+import (
+ "context"
+
+ "github.com/eduvpn/eduvpn-common/internal/server/api"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/institute"
+ "github.com/eduvpn/eduvpn-common/types/server"
+)
+
+type (
+ Server = institute.Server
+ Servers = institute.Servers
+)
+
+func New(ctx context.Context, url string) (*Server, error) {
+ b := base.Base{
+ URL: url,
+ DisplayName: map[string]string{"en": url},
+ Type: server.TypeCustom,
+ }
+ if err := api.Endpoints(ctx, &b); err != nil {
+ return nil, err
+ }
+ API := b.Endpoints.API.V3
+
+ s := &Server{Basic: b}
+ s.Auth.Init(url, API.Authorization, API.Token)
+ return s, nil
+}
diff --git a/internal/server/endpoints/endpoints.go b/internal/server/endpoints/endpoints.go
new file mode 100644
index 0000000..75bca55
--- /dev/null
+++ b/internal/server/endpoints/endpoints.go
@@ -0,0 +1,53 @@
+package endpoints
+
+import (
+ "net/url"
+
+ "github.com/go-errors/errors"
+)
+
+type List struct {
+ API string `json:"api_endpoint"`
+ Authorization string `json:"authorization_endpoint"`
+ Token string `json:"token_endpoint"`
+}
+
+type Versions struct {
+ V2 List `json:"http://eduvpn.org/api#2"`
+ V3 List `json:"http://eduvpn.org/api#3"`
+}
+
+// Endpoints defines the json format for /.well-known/vpn-user-portal".
+type Endpoints struct {
+ API Versions `json:"api"`
+ V string `json:"v"`
+}
+
+func (e Endpoints) Validate() error {
+ v3 := e.API.V3
+ pAPI, err := url.Parse(v3.API)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API endpoint", 0)
+ }
+ pAuth, err := url.Parse(v3.Authorization)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0)
+ }
+ pToken, err := url.Parse(v3.Token)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to parse API token endpoint", 0)
+ }
+ if pAPI.Scheme != pAuth.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme)
+ }
+ if pAPI.Scheme != pToken.Scheme {
+ return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme)
+ }
+ if pAPI.Host != pAuth.Host {
+ return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host)
+ }
+ if pAPI.Host != pToken.Host {
+ return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host)
+ }
+ return nil
+}
diff --git a/internal/server/institute/institute.go b/internal/server/institute/institute.go
new file mode 100644
index 0000000..ada1977
--- /dev/null
+++ b/internal/server/institute/institute.go
@@ -0,0 +1,106 @@
+package institute
+
+import (
+ "context"
+
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
+ "github.com/eduvpn/eduvpn-common/internal/server/api"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/go-errors/errors"
+)
+
+type Server struct {
+ // An instute access server has its own OAuth
+ Auth oauth.OAuth `json:"oauth"`
+
+ // Embed the server base
+ Basic base.Base `json:"base"`
+}
+
+type Servers struct {
+ Map map[string]*Server `json:"map"`
+ CurrentURL string `json:"current_url"`
+}
+
+func New(
+ ctx context.Context,
+ url string,
+ name map[string]string,
+ supportContact []string,
+) (*Server, error) {
+ b := base.Base{
+ URL: url,
+ DisplayName: name,
+ SupportContact: supportContact,
+ Type: server.TypeInstituteAccess,
+ }
+ if err := api.Endpoints(ctx, &b); err != nil {
+ return nil, err
+ }
+ API := b.Endpoints.API.V3
+
+ s := &Server{Basic: b}
+ s.Auth.Init(url, API.Authorization, API.Token)
+ return s, nil
+}
+
+func (s *Servers) Current() (*Server, error) {
+ if s.Map == nil {
+ return nil, errors.Errorf("No map is found when getting the current server")
+ }
+
+ srv, ok := s.Map[s.CurrentURL]
+ if !ok || srv == nil {
+ return nil, errors.Errorf("server not found")
+ }
+ return srv, nil
+}
+
+func (s *Servers) Remove(url string) error {
+ // check if it is in the map to begin with
+ if _, ok := s.Map[url]; ok {
+ delete(s.Map, url)
+ } else {
+ return errors.Errorf("cannot remove URL: %v, not found in list", url)
+ }
+
+ // Reset the current url
+ if s.CurrentURL == url {
+ s.CurrentURL = ""
+ }
+ return nil
+}
+
+func (s *Servers) Add(srv *Server) {
+ if s.Map == nil {
+ s.Map = make(map[string]*Server)
+ }
+ s.Map[srv.Basic.URL] = srv
+}
+
+func (s *Server) TemplateAuth() func(string) string {
+ return func(authURL string) string {
+ return authURL
+ }
+}
+
+func (s *Server) Base() (*base.Base, error) {
+ return &s.Basic, nil
+}
+
+func (s *Server) OAuth() *oauth.OAuth {
+ return &s.Auth
+}
+
+func (s *Server) NeedsLocation() bool {
+ return false
+}
+
+func (s *Server) Public() (interface{}, error) {
+ return &server.Server{
+ DisplayName: s.Basic.DisplayName,
+ Identifier: s.Basic.URL,
+ Profiles: s.Basic.Profiles.Public(),
+ }, nil
+}
diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go
deleted file mode 100644
index ebafb26..0000000
--- a/internal/server/instituteaccess.go
+++ /dev/null
@@ -1,114 +0,0 @@
-package server
-
-import (
- "github.com/eduvpn/eduvpn-common/internal/discovery"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/go-errors/errors"
-)
-
-type InstituteAccessServer struct {
- // An instute access server has its own OAuth
- Auth oauth.OAuth `json:"oauth"`
-
- // Embed the server base
- Basic Base `json:"base"`
-}
-
-type InstituteAccessServers struct {
- Map map[string]*InstituteAccessServer `json:"map"`
- CurrentURL string `json:"current_url"`
-}
-
-func (ss *Servers) SetInstituteAccess(srv Server) error {
- b, err := srv.Base()
- if err != nil {
- return err
- }
-
- if b.Type != "institute_access" {
- return errors.Errorf("not an institute access server, URL: %s, type: %s", b.URL, b.Type)
- }
-
- if _, ok := ss.InstituteServers.Map[b.URL]; ok {
- ss.InstituteServers.CurrentURL = b.URL
- ss.IsType = InstituteAccessServerType
- } else {
- return errors.Errorf("institute access server with URL: %s, is not yet configured", b.URL)
- }
- return nil
-}
-
-func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) {
- if srv, ok := ss.InstituteServers.Map[url]; ok {
- return srv, nil
- }
- return nil, errors.Errorf("no institute access server with URL: %s", url)
-}
-
-func (ss *Servers) RemoveInstituteAccess(url string) {
- ss.InstituteServers.Remove(url)
-}
-
-func (iass *InstituteAccessServers) Remove(url string) {
- // Reset the current url
- if iass.CurrentURL == url {
- iass.CurrentURL = ""
- }
-
- // Delete the url from the map
- delete(iass.Map, url)
-}
-
-func (ias *InstituteAccessServer) TemplateAuth() func(string) string {
- return func(authURL string) string {
- return authURL
- }
-}
-
-func (ias *InstituteAccessServer) Base() (*Base, error) {
- return &ias.Basic, nil
-}
-
-func (ias *InstituteAccessServer) OAuth() *oauth.OAuth {
- return &ias.Auth
-}
-
-func (ias *InstituteAccessServer) RefreshEndpoints(_ *discovery.Discovery) error {
- // Re-initialize the endpoints
- b, err := ias.Base()
- if err != nil {
- return err
- }
-
- err = b.InitializeEndpoints()
- if err != nil {
- return err
- }
-
- // update OAuth
- auth := ias.OAuth()
- if auth != nil {
- auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization
- auth.TokenURL = b.Endpoints.API.V3.Token
- }
- return nil
-}
-
-func (ias *InstituteAccessServer) init(
- url string,
- name map[string]string,
- srvType string,
- supportContact []string,
-) error {
- ias.Basic.URL = url
- ias.Basic.DisplayName = name
- ias.Basic.SupportContact = supportContact
- ias.Basic.Type = srvType
- err := ias.Basic.InitializeEndpoints()
- if err != nil {
- return err
- }
- API := ias.Basic.Endpoints.API.V3
- ias.Auth.Init(url, API.Authorization, API.Token)
- return nil
-}
diff --git a/internal/server/list.go b/internal/server/list.go
new file mode 100644
index 0000000..2660102
--- /dev/null
+++ b/internal/server/list.go
@@ -0,0 +1,179 @@
+package server
+
+import (
+ "context"
+
+ "github.com/eduvpn/eduvpn-common/internal/server/custom"
+ "github.com/eduvpn/eduvpn-common/internal/server/institute"
+ "github.com/eduvpn/eduvpn-common/internal/server/secure"
+ discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/go-errors/errors"
+)
+
+type List struct {
+ CustomServers custom.Servers `json:"custom_servers"`
+ InstituteServers institute.Servers `json:"institute_servers"`
+ SecureInternetHomeServer secure.Server `json:"secure_internet_home"`
+ IsType srvtypes.Type `json:"is_secure_internet"`
+}
+
+// HasSecureInternet returns whether or not we have a secure internet server added
+func (l *List) HasSecureInternet() bool {
+ return len(l.SecureInternetHomeServer.BaseMap) > 0
+}
+
+func (l *List) HasSecureLocation() bool {
+ return l.SecureInternetHomeServer.CurrentLocation != ""
+}
+
+func (l *List) Current() (Server, error) {
+ if l.IsType == srvtypes.TypeUnknown {
+ return nil, errors.New("no current server")
+ }
+ if l.IsType == srvtypes.TypeSecureInternet {
+ if !l.HasSecureLocation() {
+ return nil, errors.Errorf("Current server is secure internet but there is no secure internet location: %v", l.IsType)
+ }
+ return &l.SecureInternetHomeServer, nil
+ }
+
+ if l.IsType == srvtypes.TypeCustom {
+ return l.CustomServers.Current()
+ }
+ return l.InstituteServers.Current()
+}
+
+func (l *List) AddCustom(ctx context.Context, url string) (Server, error) {
+ srv, err := custom.New(ctx, url)
+ if err != nil {
+ return nil, err
+ }
+ l.CustomServers.Add(srv)
+ return srv, nil
+}
+
+func (l *List) AddInstituteAccess(ctx context.Context, discoServer *discotypes.Server) (Server, error) {
+ srv, err := institute.New(ctx, discoServer.BaseURL, discoServer.DisplayName, discoServer.SupportContact)
+ if err != nil {
+ return nil, err
+ }
+ l.InstituteServers.Add(srv)
+ return srv, nil
+}
+
+func (l *List) AddSecureInternet(
+ ctx context.Context,
+ secureOrg *discotypes.Organization,
+ secureServer *discotypes.Server,
+) (*secure.Server, error) {
+ // If we have specified an organization ID
+ // We also need to get an authorization template
+ err := l.SecureInternetHomeServer.Init(ctx, secureOrg, secureServer)
+ if err != nil {
+ return nil, err
+ }
+
+ l.IsType = srvtypes.TypeSecureInternet
+ return &l.SecureInternetHomeServer, nil
+}
+
+func (l *List) SecureInternet(identifier string) (*secure.Server, error) {
+ if l.SecureInternetHomeServer.HomeOrganizationID != identifier {
+ return nil, errors.Errorf("no secure internet home server with identifier: %s", identifier)
+ }
+ return &l.SecureInternetHomeServer, nil
+}
+
+func (l *List) SetSecureInternet(server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
+ }
+
+ if b.Type != srvtypes.TypeSecureInternet {
+ return errors.New("not a secure internet server")
+ }
+
+ // The location should already be configured
+ // TODO: check for location?
+ l.IsType = srvtypes.TypeSecureInternet
+ return nil
+}
+
+func (l *List) RemoveSecureInternet(identifier string) error {
+ oid := l.SecureInternetHomeServer.HomeOrganizationID
+ if identifier != oid {
+ return errors.Errorf("cannot remove secure internet server: identifier: %s, is not equal to the Org ID: %s", identifier, oid)
+ }
+ // Empty out the struct
+ l.SecureInternetHomeServer = secure.Server{}
+
+ // If the current server is secure internet, reset to unknown
+ if l.IsType == srvtypes.TypeSecureInternet {
+ l.IsType = srvtypes.TypeUnknown
+ }
+ return nil
+}
+
+func (l *List) SetInstituteAccess(srv Server) error {
+ b, err := srv.Base()
+ if err != nil {
+ return err
+ }
+
+ if b.Type != srvtypes.TypeInstituteAccess {
+ return errors.Errorf("not an institute access server, URL: %s, type: %v", b.URL, b.Type)
+ }
+
+ if _, ok := l.InstituteServers.Map[b.URL]; ok {
+ l.InstituteServers.CurrentURL = b.URL
+ l.IsType = srvtypes.TypeInstituteAccess
+ } else {
+ return errors.Errorf("institute access server with URL: %s, is not yet configured", b.URL)
+ }
+ return nil
+}
+
+func (l *List) InstituteAccess(url string) (*institute.Server, error) {
+ if srv, ok := l.InstituteServers.Map[url]; ok {
+ return srv, nil
+ }
+ return nil, errors.Errorf("no institute access server with URL: %s", url)
+}
+
+func (l *List) RemoveInstituteAccess(url string) error {
+ // TODO: Reset current to unknown?
+ return l.InstituteServers.Remove(url)
+}
+
+func (l *List) SetCustom(server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
+ }
+
+ if b.Type != srvtypes.TypeCustom {
+ return errors.New("not a custom server")
+ }
+
+ if _, ok := l.CustomServers.Map[b.URL]; ok {
+ l.CustomServers.CurrentURL = b.URL
+ l.IsType = srvtypes.TypeCustom
+ } else {
+ return errors.Errorf("this server is not yet added as a custom server: %s", b.URL)
+ }
+ return nil
+}
+
+func (l *List) CustomServer(url string) (*institute.Server, error) {
+ if srv, ok := l.CustomServers.Map[url]; ok {
+ return srv, nil
+ }
+ return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url)
+}
+
+func (l *List) RemoveCustom(url string) error {
+ // TODO: Reset current to unknown?
+ return l.CustomServers.Remove(url)
+}
diff --git a/internal/server/profile.go b/internal/server/profile.go
deleted file mode 100644
index d981421..0000000
--- a/internal/server/profile.go
+++ /dev/null
@@ -1,44 +0,0 @@
-package server
-
-type Profile struct {
- ID string `json:"profile_id"`
- DisplayName string `json:"display_name"`
- VPNProtoList []string `json:"vpn_proto_list"`
- DefaultGateway bool `json:"default_gateway"`
-}
-
-type ProfileListInfo struct {
- ProfileList []Profile `json:"profile_list"`
-}
-
-type ProfileInfo struct {
- Current string `json:"current_profile"`
- Info ProfileListInfo `json:"info"`
-}
-
-func (info ProfileInfo) CurrentProfileIndex() int {
- for i, profile := range info.Info.ProfileList {
- if profile.ID == info.Current {
- return i
- }
- }
- // Default is 'first' profile
- return 0
-}
-
-func (profile *Profile) supportsProtocol(protocol string) bool {
- for _, proto := range profile.VPNProtoList {
- if proto == protocol {
- return true
- }
- }
- return false
-}
-
-func (profile *Profile) SupportsWireguard() bool {
- return profile.supportsProtocol("wireguard")
-}
-
-func (profile *Profile) SupportsOpenVPN() bool {
- return profile.supportsProtocol("openvpn")
-}
diff --git a/internal/server/profile/profile.go b/internal/server/profile/profile.go
new file mode 100644
index 0000000..7a19685
--- /dev/null
+++ b/internal/server/profile/profile.go
@@ -0,0 +1,88 @@
+package profile
+
+import (
+ "github.com/eduvpn/eduvpn-common/types/protocol"
+ "github.com/eduvpn/eduvpn-common/types/server"
+)
+
+type Profile struct {
+ ID string `json:"profile_id"`
+ DisplayName string `json:"display_name"`
+ VPNProtoList []string `json:"vpn_proto_list"`
+ DefaultGateway bool `json:"default_gateway"`
+}
+
+type ListInfo struct {
+ ProfileList []Profile `json:"profile_list"`
+}
+
+type Info struct {
+ Current string `json:"current_profile"`
+ Info ListInfo `json:"info"`
+}
+
+func (info Info) CurrentProfileIndex() int {
+ for i, profile := range info.Info.ProfileList {
+ if profile.ID == info.Current {
+ return i
+ }
+ }
+ // Default is 'first' profile
+ return 0
+}
+
+func (profile *Profile) supportsProtocol(protocol string) bool {
+ for _, proto := range profile.VPNProtoList {
+ if proto == protocol {
+ return true
+ }
+ }
+ return false
+}
+
+func (profile *Profile) SupportsWireguard() bool {
+ return profile.supportsProtocol("wireguard")
+}
+
+func (profile *Profile) SupportsOpenVPN() bool {
+ return profile.supportsProtocol("openvpn")
+}
+
+func (info Info) Supported(wireguardSupport bool) []Profile {
+ var valid []Profile
+ for _, p := range info.Info.ProfileList {
+ // Not a valid profile because it does not support openvpn
+ // Also the client does not support wireguard
+ if !p.SupportsOpenVPN() && !wireguardSupport {
+ continue
+ }
+ valid = append(valid, p)
+ }
+ return valid
+}
+
+func (info Info) Has(id string) bool {
+ for _, p := range info.Info.ProfileList {
+ if p.ID == id {
+ return true
+ }
+ }
+ return false
+}
+
+func (info Info) Public() server.Profiles {
+ m := make(map[string]server.Profile)
+ for _, p := range info.Info.ProfileList {
+ var protocols []protocol.Protocol
+ for _, ps := range p.VPNProtoList {
+ protocols = append(protocols, protocol.New(ps))
+ }
+ m[p.ID] = server.Profile{
+ DisplayName: map[string]string{
+ "en": p.DisplayName,
+ },
+ Protocols: protocols,
+ }
+ }
+ return server.Profiles{Map: m, Current: info.Current}
+}
diff --git a/internal/server/profile_test.go b/internal/server/profile/profile_test.go
index d6a7e9d..e246b5c 100644
--- a/internal/server/profile_test.go
+++ b/internal/server/profile/profile_test.go
@@ -1,4 +1,4 @@
-package server
+package profile
import "testing"
@@ -86,9 +86,9 @@ func Test_CurrentProfileIndex(t *testing.T) {
}
for _, tc := range testCases {
- pri := &ProfileInfo{
+ pri := &Info{
Current: tc.current,
- Info: ProfileListInfo{
+ Info: ListInfo{
ProfileList: tc.profiles,
},
}
diff --git a/internal/server/secure/secure.go b/internal/server/secure/secure.go
new file mode 100644
index 0000000..6fed010
--- /dev/null
+++ b/internal/server/secure/secure.go
@@ -0,0 +1,148 @@
+package secure
+
+import (
+ "context"
+ "github.com/eduvpn/eduvpn-common/internal/oauth"
+ "github.com/eduvpn/eduvpn-common/internal/server/api"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/util"
+ discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
+ "github.com/eduvpn/eduvpn-common/types/server"
+ "github.com/go-errors/errors"
+)
+
+// Server secure internet server which has its own OAuth tokens
+// It specifies the current location url it is connected to.
+type Server struct {
+ Auth oauth.OAuth `json:"oauth"`
+ DisplayName map[string]string `json:"display_name"`
+
+ // The home server has a list of info for each configured server location
+ BaseMap map[string]*base.Base `json:"base_map"`
+
+ // We have the authorization URL template, the home organization ID and the current location
+ AuthorizationTemplate string `json:"authorization_template"`
+ HomeOrganizationID string `json:"home_organization_id"`
+ CurrentLocation string `json:"current_location"`
+}
+
+func (s *Server) TemplateAuth() func(string) string {
+ return func(authURL string) string {
+ return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID)
+ }
+}
+
+func (s *Server) Base() (*base.Base, error) {
+ if s.BaseMap == nil {
+ return nil, errors.Errorf("secure internet map not found")
+ }
+
+ b, ok := s.BaseMap[s.CurrentLocation]
+ if !ok {
+ return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation)
+ }
+ return b, nil
+}
+
+func (s *Server) OAuth() *oauth.OAuth {
+ return &s.Auth
+}
+
+func (s *Server) NeedsLocation() bool {
+ if s.CurrentLocation == "" {
+ return true
+ }
+ if len(s.BaseMap) == 0 {
+ return true
+ }
+ return false
+}
+
+func (s *Server) addLocation(ctx context.Context, locSrv *discotypes.Server) (*base.Base, error) {
+ // Initialize the base map if it is non-nil
+ if s.BaseMap == nil {
+ s.BaseMap = make(map[string]*base.Base)
+ }
+
+ // Add the location to the base map
+ b, ok := s.BaseMap[locSrv.CountryCode]
+ if !ok || b == nil {
+ // Create the base to be added to the map
+ b = &base.Base{}
+ b.URL = locSrv.BaseURL
+ b.DisplayName = s.DisplayName
+ b.SupportContact = locSrv.SupportContact
+ b.Type = server.TypeSecureInternet
+ if err := api.Endpoints(ctx, b); err != nil {
+ return nil, err
+ }
+ }
+
+ // Ensure it is in the map
+ s.BaseMap[locSrv.CountryCode] = b
+ return b, nil
+}
+
+func (s *Server) Location(ctx context.Context, locSrv *discotypes.Server) error {
+ if _, err := s.addLocation(ctx, locSrv); err != nil {
+ return err
+ }
+ s.CurrentLocation = locSrv.CountryCode
+ return nil
+}
+
+// Initializes the home server and adds its own location.
+func (s *Server) Init(
+ ctx context.Context,
+ homeOrg *discotypes.Organization, homeLoc *discotypes.Server,
+) error {
+ if s.HomeOrganizationID != homeOrg.OrgID {
+ // New home organisation, clear everything
+ *s = Server{}
+ }
+
+ // Make sure to set the organization ID
+ s.HomeOrganizationID = homeOrg.OrgID
+ s.DisplayName = homeOrg.DisplayName
+
+ // Make sure to set the authorization URL template
+ s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate
+
+ b, err := s.addLocation(ctx, homeLoc)
+ if err != nil {
+ return err
+ }
+
+ // set the home location as the current
+ err = s.Location(ctx, homeLoc)
+ if err != nil {
+ return err
+ }
+
+ // Set the current location to the home location if there is none
+ if s.CurrentLocation == "" {
+ s.CurrentLocation = homeLoc.CountryCode
+ }
+
+ // Make sure oauth contains our endpoints
+ s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token)
+ return nil
+}
+
+func (s *Server) Public() (interface{}, error) {
+ b, err := s.Base()
+ var p server.Profiles
+ dn := s.DisplayName
+ if err == nil {
+ dn = b.DisplayName
+ p = b.Profiles.Public()
+ }
+ return &server.SecureInternet{
+ Server: server.Server{
+ DisplayName: dn,
+ Identifier: s.HomeOrganizationID,
+ Profiles: p,
+ },
+ CountryCode: s.CurrentLocation,
+ }, nil
+}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
deleted file mode 100644
index 4b42303..0000000
--- a/internal/server/secureinternet.go
+++ /dev/null
@@ -1,175 +0,0 @@
-package server
-
-import (
- "github.com/eduvpn/eduvpn-common/internal/discovery"
- "github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/internal/util"
- discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
- "github.com/go-errors/errors"
-)
-
-// SecureInternetHomeServer secure internet server which has its own OAuth tokens
-// It specifies the current location url it is connected to.
-type SecureInternetHomeServer struct {
- Auth oauth.OAuth `json:"oauth"`
- DisplayName map[string]string `json:"display_name"`
-
- // The home server has a list of info for each configured server location
- BaseMap map[string]*Base `json:"base_map"`
-
- // We have the authorization URL template, the home organization ID and the current location
- AuthorizationTemplate string `json:"authorization_template"`
- HomeOrganizationID string `json:"home_organization_id"`
- CurrentLocation string `json:"current_location"`
-}
-
-func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) {
- if !ss.HasSecureLocation() {
- return nil, errors.Errorf("no secure internet home server")
- }
- return &ss.SecureInternetHomeServer, nil
-}
-
-func (ss *Servers) SetSecureInternet(server Server) error {
- b, err := server.Base()
- if err != nil {
- return err
- }
-
- if b.Type != "secure_internet" {
- return errors.Errorf("not a secure internet server")
- }
-
- // The location should already be configured
- // TODO: check for location?
- ss.IsType = SecureInternetServerType
- return nil
-}
-
-func (ss *Servers) RemoveSecureInternet() {
- // Empty out the struct
- ss.SecureInternetHomeServer = SecureInternetHomeServer{}
-
- // If the current server is secure internet, default to custom server
- if ss.IsType == SecureInternetServerType {
- ss.IsType = CustomServerType
- }
-}
-
-func (s *SecureInternetHomeServer) TemplateAuth() func(string) string {
- return func(authURL string) string {
- return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID)
- }
-}
-
-func (s *SecureInternetHomeServer) Base() (*Base, error) {
- if s.BaseMap == nil {
- return nil, errors.Errorf("secure internet map not found")
- }
-
- b, ok := s.BaseMap[s.CurrentLocation]
- if !ok {
- return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation)
- }
- return b, nil
-}
-
-func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth {
- return &s.Auth
-}
-
-func (ss *Servers) HasSecureLocation() bool {
- return ss.SecureInternetHomeServer.CurrentLocation != ""
-}
-
-func (s *SecureInternetHomeServer) addLocation(locSrv *discotypes.Server) (*Base, error) {
- // Initialize the base map if it is non-nil
- if s.BaseMap == nil {
- s.BaseMap = make(map[string]*Base)
- }
-
- // Add the location to the base map
- b, ok := s.BaseMap[locSrv.CountryCode]
- if !ok || b == nil {
- // Create the base to be added to the map
- b = &Base{}
- b.URL = locSrv.BaseURL
- b.DisplayName = s.DisplayName
- b.SupportContact = locSrv.SupportContact
- b.Type = "secure_internet"
- if err := b.InitializeEndpoints(); err != nil {
- return nil, err
- }
- }
-
- // Ensure it is in the map
- s.BaseMap[locSrv.CountryCode] = b
- return b, nil
-}
-
-// Initializes the home server and adds its own location.
-func (s *SecureInternetHomeServer) init(
- homeOrg *discotypes.Organization, homeLoc *discotypes.Server,
-) error {
- if s.HomeOrganizationID != homeOrg.OrgID {
- // New home organisation, clear everything
- *s = SecureInternetHomeServer{}
- }
-
- // Make sure to set the organization ID
- s.HomeOrganizationID = homeOrg.OrgID
- s.DisplayName = homeOrg.DisplayName
-
- // Make sure to set the authorization URL template
- s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate
-
- b, err := s.addLocation(homeLoc)
- if err != nil {
- return err
- }
-
- // Set the current location to the home location if there is none
- if s.CurrentLocation == "" {
- s.CurrentLocation = homeLoc.CountryCode
- }
-
- // Make sure oauth contains our endpoints
- s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token)
- return nil
-}
-
-func (s *SecureInternetHomeServer) RefreshEndpoints(disco *discovery.Discovery) error {
- // update OAuth for home server
- auth := s.OAuth()
- if auth != nil && s.HomeOrganizationID != "" {
- _, srv, err := disco.SecureHomeArgs(s.HomeOrganizationID)
- if err != nil {
- return err
- }
- if hb, ok := s.BaseMap[srv.CountryCode]; ok && hb != nil {
- err := hb.InitializeEndpoints()
- if err != nil {
- return err
- }
- auth.BaseAuthorizationURL = hb.Endpoints.API.V3.Authorization
- auth.TokenURL = hb.Endpoints.API.V3.Token
- }
- // already updated, return
- if srv.CountryCode == s.CurrentLocation {
- return nil
- }
- }
-
- // refresh the current location endpoints
- // Re-initialize the endpoints
- b, err := s.Base()
- if err != nil {
- return err
- }
-
- err = b.InitializeEndpoints()
- if err != nil {
- return err
- }
- return nil
-}
diff --git a/internal/server/server.go b/internal/server/server.go
index 4bd8766..e7229c5 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,23 +1,21 @@
package server
import (
+ "context"
"os"
"time"
"github.com/eduvpn/eduvpn-common/internal/discovery"
"github.com/eduvpn/eduvpn-common/internal/oauth"
+ "github.com/eduvpn/eduvpn-common/internal/server/api"
+ "github.com/eduvpn/eduvpn-common/internal/server/base"
+ "github.com/eduvpn/eduvpn-common/internal/server/profile"
"github.com/eduvpn/eduvpn-common/internal/wireguard"
+ "github.com/eduvpn/eduvpn-common/types/protocol"
+ srvtypes "github.com/eduvpn/eduvpn-common/types/server"
"github.com/go-errors/errors"
)
-type Type int8
-
-const (
- CustomServerType Type = iota
- InstituteAccessServerType
- SecureInternetServerType
-)
-
type Server interface {
OAuth() *oauth.OAuth
@@ -25,27 +23,13 @@ type Server interface {
TemplateAuth() func(string) string
// Base returns the server base
- Base() (*Base, error)
+ Base() (*base.Base, error)
- // RefreshEndpoints
- RefreshEndpoints(*discovery.Discovery) error
-}
+ // NeedsLocation checks if the server needs a secure internet location
+ NeedsLocation() bool
-type EndpointList struct {
- API string `json:"api_endpoint"`
- Authorization string `json:"authorization_endpoint"`
- Token string `json:"token_endpoint"`
-}
-
-type EndpointsVersions struct {
- V2 EndpointList `json:"http://eduvpn.org/api#2"`
- V3 EndpointList `json:"http://eduvpn.org/api#3"`
-}
-
-// Endpoints defines the json format for /.well-known/vpn-user-portal".
-type Endpoints struct {
- API EndpointsVersions `json:"api"`
- V string `json:"v"`
+ // Public returns the representation that will be passed over the CGO barrier
+ Public() (interface{}, error)
}
func UpdateTokens(srv Server, t oauth.Token) {
@@ -56,12 +40,12 @@ func OAuthURL(srv Server, name string) (string, error) {
return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
-func OAuthExchange(srv Server) error {
- return srv.OAuth().Exchange()
+func OAuthExchange(ctx context.Context, srv Server) error {
+ return srv.OAuth().Exchange(ctx)
}
-func HeaderToken(srv Server) (string, error) {
- return srv.OAuth().AccessToken()
+func HeaderToken(ctx context.Context, srv Server) (string, error) {
+ return srv.OAuth().AccessToken(ctx)
}
func MarkTokenExpired(srv Server) {
@@ -72,16 +56,13 @@ func MarkTokensForRenew(srv Server) {
srv.OAuth().SetTokenRenew()
}
-func NeedsRelogin(srv Server) bool {
- _, err := HeaderToken(srv)
+func NeedsRelogin(ctx context.Context, srv Server) bool {
+ // TODO: this error can be a context cancel
+ _, err := HeaderToken(ctx, srv)
return err != nil
}
-func CancelOAuth(srv Server) {
- srv.OAuth().Cancel()
-}
-
-func CurrentProfile(srv Server) (*Profile, error) {
+func CurrentProfile(srv Server) (*profile.Profile, error) {
b, err := srv.Base()
if err != nil {
return nil, err
@@ -96,19 +77,31 @@ func CurrentProfile(srv Server) (*Profile, error) {
return nil, errors.Errorf("profile not found: " + pID)
}
-func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
+func ValidProfiles(srv Server, wireguardSupport bool) (*[]profile.Profile, error) {
// No error wrapping here otherwise we wrap it too much
b, err := srv.Base()
if err != nil {
return nil, err
}
- ps := b.ValidProfiles(wireguardSupport)
- if len(ps.Info.ProfileList) == 0 {
+ ps := b.Profiles.Supported(wireguardSupport)
+ if len(ps) == 0 {
return nil, errors.Errorf("no profiles found with supported protocols")
}
return &ps, nil
}
+func Profile(srv Server, id string) error {
+ b, err := srv.Base()
+ if err != nil {
+ return err
+ }
+ if !b.Profiles.Has(id) {
+ return errors.Errorf("no profile available with id: %s", id)
+ }
+ b.Profiles.Current = id
+ return nil
+}
+
type ConfigData struct {
// The configuration
Config string
@@ -120,7 +113,18 @@ type ConfigData struct {
Tokens oauth.Token
}
-func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
+// Public gets the public data from the types package
+// dg specifies if this config is default gateway
+func (c *ConfigData) Public(dg bool) srvtypes.Configuration {
+ return srvtypes.Configuration{
+ VPNConfig: c.Config,
+ Protocol: protocol.New(c.Type),
+ DefaultGateway: dg,
+ Tokens: c.Tokens.Public(),
+ }
+}
+
+func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
return nil, err
@@ -133,7 +137,7 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi
}
pub := key.PublicKey().String()
- cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport)
+ cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport)
if err != nil {
return nil, err
}
@@ -159,13 +163,13 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi
return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil
}
-func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) {
+func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
return nil, err
}
pid := b.Profiles.Current
- cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
+ cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP)
if err != nil {
return nil, err
}
@@ -184,15 +188,14 @@ func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) {
return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil
}
-func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
- // Get new profiles using the info call
- // This does not override the current profile
- err := APIInfo(srv)
+func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) {
+ b, err := srv.Base()
if err != nil {
return false, err
}
-
- b, err := srv.Base()
+ // Get new profiles using the info call
+ // This does not override the current profile
+ err = api.Info(ctx, b, srv.OAuth())
if err != nil {
return false, err
}
@@ -225,7 +228,18 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
return true, nil
}
-func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
+func RefreshEndpoints(ctx context.Context, srv Server) error {
+ // Re-initialize the endpoints
+ // TODO: Make this a warning instead?
+ b, err := srv.Base()
+ if err != nil {
+ return err
+ }
+
+ return api.Endpoints(ctx, b)
+}
+
+func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
p, err := CurrentProfile(server)
if err != nil {
return nil, err
@@ -250,10 +264,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData,
case wg:
// A wireguard connect call needs to generate a wireguard key and add it to the config
// Also the server could send back an OpenVPN config if it supports OpenVPN
- cfg, err = wireguardGetConfig(server, preferTCP, ovpn)
+ cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn)
// The config only supports OpenVPN
case ovpn:
- cfg, err = openVPNGetConfig(server, preferTCP)
+ cfg, err = openVPNGetConfig(ctx, server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
return nil, errors.New("no supported protocol found")
@@ -267,6 +281,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData,
return cfg, err
}
-func Disconnect(server Server) error {
- return APIDisconnect(server)
+func Disconnect(ctx context.Context, server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
+ }
+ return api.Disconnect(ctx, b, server.OAuth())
}
diff --git a/internal/server/servers.go b/internal/server/servers.go
deleted file mode 100644
index 60c993d..0000000
--- a/internal/server/servers.go
+++ /dev/null
@@ -1,121 +0,0 @@
-package server
-
-import (
- discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
- "github.com/go-errors/errors"
-)
-
-// TODO: Have a dedicated type for custom servers
-type Servers struct {
- // A custom server is just an institute access server under the hood
- CustomServers InstituteAccessServers `json:"custom_servers"`
- InstituteServers InstituteAccessServers `json:"institute_servers"`
- SecureInternetHomeServer SecureInternetHomeServer `json:"secure_internet_home"`
- IsType Type `json:"is_secure_internet"`
-}
-
-// HasSecureInternet returns whether or not we have a secure internet server added
-func (ss *Servers) HasSecureInternet() bool {
- return len(ss.SecureInternetHomeServer.BaseMap) > 0
-}
-
-func (ss *Servers) AddSecureInternet(
- secureOrg *discotypes.Organization,
- secureServer *discotypes.Server,
-) (Server, error) {
- // If we have specified an organization ID
- // We also need to get an authorization template
- err := ss.SecureInternetHomeServer.init(secureOrg, secureServer)
- if err != nil {
- return nil, err
- }
-
- ss.IsType = SecureInternetServerType
- return &ss.SecureInternetHomeServer, nil
-}
-
-func (ss *Servers) GetCurrentServer() (Server, error) {
- // TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server)
- if ss.IsType == SecureInternetServerType {
- if !ss.HasSecureLocation() {
- return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType)
- }
- return &ss.SecureInternetHomeServer, nil
- }
-
- srvs := &ss.InstituteServers
-
- if ss.IsType == CustomServerType {
- srvs = &ss.CustomServers
- }
- if srvs.Map == nil {
- return nil, errors.Errorf("srvs.Map is nil")
- }
-
- srv, ok := srvs.Map[srvs.CurrentURL]
- if !ok || srv == nil {
- return nil, errors.Errorf("server not found")
- }
- return srv, nil
-}
-
-func (ss *Servers) addInstituteAndCustom(
- discoServer *discotypes.Server,
- isCustom bool,
-) (Server, error) {
- URL := discoServer.BaseURL
- srvs := &ss.InstituteServers
- srvType := InstituteAccessServerType
-
- if isCustom {
- srvs = &ss.CustomServers
- srvType = CustomServerType
- }
-
- if srvs.Map == nil {
- srvs.Map = make(map[string]*InstituteAccessServer)
- }
-
- srv, ok := srvs.Map[URL]
-
- // initialize the server if it doesn't exist yet
- if !ok {
- srv = &InstituteAccessServer{}
- }
-
- if err := srv.init(URL, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil {
- return nil, err
- }
- srvs.Map[URL] = srv
- ss.IsType = srvType
- return srv, nil
-}
-
-func (ss *Servers) AddInstituteAccessServer(
- instituteServer *discotypes.Server,
-) (Server, error) {
- return ss.addInstituteAndCustom(instituteServer, false)
-}
-
-func (ss *Servers) AddCustomServer(
- customServer *discotypes.Server,
-) (Server, error) {
- return ss.addInstituteAndCustom(customServer, true)
-}
-
-func (ss *Servers) GetSecureLocation() string {
- return ss.SecureInternetHomeServer.CurrentLocation
-}
-
-func (ss *Servers) SetSecureLocation(
- chosenLocationServer *discotypes.Server,
-) error {
- // Make sure to add the current location
-
- if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil {
- return err
- }
-
- ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
- return nil
-}