From a23c3e61c5d89ef67973891b5b3a176c06e1b174 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Wed, 12 Apr 2023 22:52:49 +0200 Subject: Refactor: Split internal server into multiple packages - Pass contexts - Have separate packages for e.g. custom, institute and secure - internet servers, profiles.... - Return types from the public ./types package with a Public() method --- internal/server/api.go | 243 -------------------------------- internal/server/api/api.go | 217 ++++++++++++++++++++++++++++ internal/server/api/api_test.go | 134 ++++++++++++++++++ internal/server/api_test.go | 130 ----------------- internal/server/base.go | 112 --------------- internal/server/base/base.go | 90 ++++++++++++ internal/server/custom.go | 35 ----- internal/server/custom/custom.go | 31 ++++ internal/server/endpoints/endpoints.go | 53 +++++++ internal/server/institute/institute.go | 106 ++++++++++++++ internal/server/instituteaccess.go | 114 --------------- internal/server/list.go | 179 +++++++++++++++++++++++ internal/server/profile.go | 44 ------ internal/server/profile/profile.go | 88 ++++++++++++ internal/server/profile/profile_test.go | 100 +++++++++++++ internal/server/profile_test.go | 100 ------------- internal/server/secure/secure.go | 148 +++++++++++++++++++ internal/server/secureinternet.go | 175 ----------------------- internal/server/server.go | 130 +++++++++-------- internal/server/servers.go | 121 ---------------- 20 files changed, 1220 insertions(+), 1130 deletions(-) delete mode 100644 internal/server/api.go create mode 100644 internal/server/api/api.go create mode 100644 internal/server/api/api_test.go delete mode 100644 internal/server/api_test.go delete mode 100644 internal/server/base.go create mode 100644 internal/server/base/base.go delete mode 100644 internal/server/custom.go create mode 100644 internal/server/custom/custom.go create mode 100644 internal/server/endpoints/endpoints.go create mode 100644 internal/server/institute/institute.go delete mode 100644 internal/server/instituteaccess.go create mode 100644 internal/server/list.go delete mode 100644 internal/server/profile.go create mode 100644 internal/server/profile/profile.go create mode 100644 internal/server/profile/profile_test.go delete mode 100644 internal/server/profile_test.go create mode 100644 internal/server/secure/secure.go delete mode 100644 internal/server/secureinternet.go delete mode 100644 internal/server/servers.go diff --git a/internal/server/api.go b/internal/server/api.go deleted file mode 100644 index 546c02a..0000000 --- a/internal/server/api.go +++ /dev/null @@ -1,243 +0,0 @@ -package server - -import ( - "encoding/json" - "fmt" - "net/http" - "net/url" - "path" - "time" - - httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/internal/log" - "github.com/go-errors/errors" -) - -func validateEndpoints(endpoints Endpoints) error { - v3 := endpoints.API.V3 - pAPI, err := url.Parse(v3.API) - if err != nil { - return errors.WrapPrefix(err, "failed to parse API endpoint", 0) - } - pAuth, err := url.Parse(v3.Authorization) - if err != nil { - return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0) - } - pToken, err := url.Parse(v3.Token) - if err != nil { - return errors.WrapPrefix(err, "failed to parse API token endpoint", 0) - } - if pAPI.Scheme != pAuth.Scheme { - return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) - } - if pAPI.Scheme != pToken.Scheme { - return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) - } - if pAPI.Host != pAuth.Host { - return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host) - } - if pAPI.Host != pToken.Host { - return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host) - } - return nil -} - -func APIGetEndpoints(baseURL string, client *httpw.Client) (*Endpoints, error) { - uStr, err := httpw.JoinURLPath(baseURL, "/.well-known/vpn-user-portal") - if err != nil { - return nil, err - } - if client == nil { - client = httpw.NewClient() - } - _, body, err := client.Get(uStr) - if err != nil { - return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) - } - - ep := Endpoints{} - if err = json.Unmarshal(body, &ep); err != nil { - return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) - } - err = validateEndpoints(ep) - if err != nil { - return nil, err - } - - return &ep, nil -} - -func apiAuthorized( - srv Server, - method string, - endpoint string, - opts *httpw.OptionalParams, -) (http.Header, []byte, error) { - // Ensure optional is not nil as we will fill it with headers - if opts == nil { - opts = &httpw.OptionalParams{} - } - errorMessage := "failed API authorized" - b, err := srv.Base() - if err != nil { - return nil, nil, errors.WrapPrefix(err, errorMessage, 0) - } - - // Join the paths - u, err := url.Parse(b.Endpoints.API.V3.API) - if err != nil { - return nil, nil, errors.WrapPrefix(err, errorMessage, 0) - } - u.Path = path.Join(u.Path, endpoint) - - // Make sure the tokens are valid, this will return an error if re-login is needed - t, err := HeaderToken(srv) - if err != nil { - return nil, nil, errors.WrapPrefix(err, errorMessage, 0) - } - - key := "Authorization" - val := fmt.Sprintf("Bearer %s", t) - if opts.Headers != nil { - opts.Headers.Add(key, val) - } else { - opts.Headers = http.Header{key: {val}} - } - - // Create a client if it doesn't exist - if b.httpClient == nil { - b.httpClient = httpw.NewClient() - } - return b.httpClient.Do(method, u.String(), opts) -} - -func apiAuthorizedRetry( - srv Server, - method string, - endpoint string, - opts *httpw.OptionalParams, -) (http.Header, []byte, error) { - h, body, err := apiAuthorized(srv, method, endpoint, opts) - if err == nil { - return h, body, nil - } - - statErr := &httpw.StatusError{} - // Only retry authorized if we get an HTTP 401 - if errors.As(err, &statErr) && statErr.Status == 401 { - log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint) - // Mark the token as expired and retry, so we trigger the refresh flow - MarkTokenExpired(srv) - h, body, err = apiAuthorized(srv, method, endpoint, opts) - } - return h, body, err -} - -func APIInfo(srv Server) error { - _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil) - if err != nil { - return err - } - profiles := ProfileInfo{} - if err = json.Unmarshal(body, &profiles); err != nil { - return errors.WrapPrefix(err, "failed API /info", 0) - } - - b, err := srv.Base() - if err != nil { - return err - } - - // Store the profiles and make sure that the current profile is not overwritten - prev := b.Profiles.Current - b.Profiles = profiles - b.Profiles.Current = prev - return nil -} - -// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 -func boolToYesNo(preferTCP bool) string { - if preferTCP { - return "yes" - } - return "no" -} - -func APIConnectWireguard( - srv Server, - profileID string, - pubkey string, - preferTCP bool, - openVPNSupport bool, -) (string, string, time.Time, error) { - hdrs := http.Header{ - "content-type": {"application/x-www-form-urlencoded"}, - "accept": {"application/x-wireguard-profile"}, - } - - // This profile also supports OpenVPN - // Indicate that we also accept OpenVPN profiles - if openVPNSupport { - hdrs.Add("accept", "application/x-openvpn-profile") - } - - vals := url.Values{ - "profile_id": {profileID}, - "public_key": {pubkey}, - "prefer_tcp": {boolToYesNo(preferTCP)}, - } - h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", - &httpw.OptionalParams{Headers: hdrs, Body: vals}) - if err != nil { - return "", "", time.Time{}, err - } - - exp := h.Get("expires") - - expTime, err := http.ParseTime(exp) - if err != nil { - return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0) - } - - contentH := h.Get("content-type") - content := "openvpn" - if contentH == "application/x-wireguard-profile" { - content = "wireguard" - } - - return string(body), content, expTime, nil -} - -func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) { - hdrs := http.Header{ - "content-type": {"application/x-www-form-urlencoded"}, - "accept": {"application/x-openvpn-profile"}, - } - - vals := url.Values{ - "profile_id": {profileID}, - "prefer_tcp": {boolToYesNo(preferTCP)}, - } - - h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", - &httpw.OptionalParams{Headers: hdrs, Body: vals}) - if err != nil { - return "", time.Time{}, err - } - - expH := h.Get("expires") - expT, err := http.ParseTime(expH) - if err != nil { - return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0) - } - - return string(body), expT, nil -} - -// APIDisconnect disconnects from the API. -func APIDisconnect(server Server) error { - // The timeout is a bit lower here such that this does not take a too long time for disconnecting - // Clients may wish to retry this - _, _, err := apiAuthorized(server, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) - return err -} diff --git a/internal/server/api/api.go b/internal/server/api/api.go new file mode 100644 index 0000000..9ad6f2d --- /dev/null +++ b/internal/server/api/api.go @@ -0,0 +1,217 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "path" + "time" + + httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/log" + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/server/profile" + "github.com/go-errors/errors" +) + +func Endpoints(ctx context.Context, b *base.Base) error { + uStr, err := httpw.JoinURLPath(b.URL, "/.well-known/vpn-user-portal") + if err != nil { + return err + } + if b.HTTPClient == nil { + b.HTTPClient = httpw.NewClient() + } + _, body, err := b.HTTPClient.Get(ctx, uStr) + if err != nil { + return errors.WrapPrefix(err, "failed getting server endpoints", 0) + } + + ep := endpoints.Endpoints{} + if err = json.Unmarshal(body, &ep); err != nil { + return errors.WrapPrefix(err, "failed getting server endpoints", 0) + } + err = ep.Validate() + if err != nil { + return err + } + + b.Endpoints = ep + return nil +} + +func authorized( + ctx context.Context, + b *base.Base, + oauth *oauth.OAuth, + method string, + endpoint string, + opts *httpw.OptionalParams, +) (http.Header, []byte, error) { + // Ensure optional is not nil as we will fill it with headers + if opts == nil { + opts = &httpw.OptionalParams{} + } + errorMessage := "failed API authorized" + + // Join the paths + u, err := url.Parse(b.Endpoints.API.V3.API) + if err != nil { + return nil, nil, errors.WrapPrefix(err, errorMessage, 0) + } + u.Path = path.Join(u.Path, endpoint) + + // Make sure the tokens are valid, this will return an error if re-login is needed + t, err := oauth.AccessToken(ctx) + if err != nil { + return nil, nil, errors.WrapPrefix(err, errorMessage, 0) + } + + key := "Authorization" + val := fmt.Sprintf("Bearer %s", t) + if opts.Headers != nil { + opts.Headers.Add(key, val) + } else { + opts.Headers = http.Header{key: {val}} + } + + // Create a client if it doesn't exist + if b.HTTPClient == nil { + b.HTTPClient = httpw.NewClient() + } + return b.HTTPClient.Do(ctx, method, u.String(), opts) +} + +func authorizedRetry( + ctx context.Context, + b *base.Base, + auth *oauth.OAuth, + method string, + endpoint string, + opts *httpw.OptionalParams, +) (http.Header, []byte, error) { + h, body, err := authorized(ctx, b, auth, method, endpoint, opts) + if err == nil { + return h, body, nil + } + + statErr := &httpw.StatusError{} + // Only retry authorized if we get an HTTP 401 + if errors.As(err, &statErr) && statErr.Status == 401 { + log.Logger.Debugf("Got a 401 error after HTTP method: %s, endpoint: %s. Marking token as expired...", method, endpoint) + // Mark the token as expired and retry, so we trigger the refresh flow + auth.SetTokenExpired() + h, body, err = authorized(ctx, b, auth, method, endpoint, opts) + } + return h, body, err +} + +func Info(ctx context.Context, b *base.Base, auth *oauth.OAuth) error { + _, body, err := authorizedRetry(ctx, b, auth, http.MethodGet, "/info", nil) + if err != nil { + return err + } + profiles := profile.Info{} + if err = json.Unmarshal(body, &profiles); err != nil { + return errors.WrapPrefix(err, "failed API /info", 0) + } + + // Store the profiles and make sure that the current profile is not overwritten + prev := b.Profiles.Current + b.Profiles = profiles + b.Profiles.Current = prev + return nil +} + +// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 +func boolToYesNo(preferTCP bool) string { + if preferTCP { + return "yes" + } + return "no" +} + +func ConnectWireguard( + ctx context.Context, + b *base.Base, + auth *oauth.OAuth, + profileID string, + pubkey string, + preferTCP bool, + openVPNSupport bool, +) (string, string, time.Time, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + "accept": {"application/x-wireguard-profile"}, + } + + // This profile also supports OpenVPN + // Indicate that we also accept OpenVPN profiles + if openVPNSupport { + hdrs.Add("accept", "application/x-openvpn-profile") + } + + vals := url.Values{ + "profile_id": {profileID}, + "public_key": {pubkey}, + "prefer_tcp": {boolToYesNo(preferTCP)}, + } + h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", "", time.Time{}, err + } + + exp := h.Get("expires") + + expTime, err := http.ParseTime(exp) + if err != nil { + return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0) + } + + contentH := h.Get("content-type") + content := "openvpn" + if contentH == "application/x-wireguard-profile" { + content = "wireguard" + } + + return string(body), content, expTime, nil +} + +func ConnectOpenVPN(ctx context.Context, b *base.Base, auth *oauth.OAuth, profileID string, preferTCP bool) (string, time.Time, error) { + hdrs := http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + "accept": {"application/x-openvpn-profile"}, + } + + vals := url.Values{ + "profile_id": {profileID}, + "prefer_tcp": {boolToYesNo(preferTCP)}, + } + + h, body, err := authorizedRetry(ctx, b, auth, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", time.Time{}, err + } + + expH := h.Get("expires") + expT, err := http.ParseTime(expH) + if err != nil { + return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0) + } + + return string(body), expT, nil +} + +// Disconnect disconnects the VPN using the API. +func Disconnect(ctx context.Context, b *base.Base, auth *oauth.OAuth) error { + // The timeout is a bit lower here such that this does not take a too long time for disconnecting + // Clients may wish to retry this + _, _, err := authorized(ctx, b, auth, http.MethodPost, "/disconnect", &httpw.OptionalParams{Timeout: 5 * time.Second}) + return err +} diff --git a/internal/server/api/api_test.go b/internal/server/api/api_test.go new file mode 100644 index 0000000..7509a30 --- /dev/null +++ b/internal/server/api/api_test.go @@ -0,0 +1,134 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/go-errors/errors" +) + +func getErrorMsg(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func compareEndpoints(ep1 endpoints.Endpoints, ep2 endpoints.Endpoints) bool { + v3_1 := ep1.API.V3 + v3_2 := ep2.API.V3 + return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token +} + +func Test_APIGetEndpoints(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello!") + }) + hs := &test.HandlerSet{} + hs.SetHandler(handler) + s := test.NewServer(hs) + defer s.Close() + + c, err := s.Client() + if err != nil { + t.Fatalf("failed to get client for test server endpoints: %v", err) + } + + testCases := []struct { + epl endpoints.List + err error + }{ + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: nil, + }, + { + epl: endpoints.List{ + API: "http://example.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"), + }, + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "ftp://example.com/3", + }, + err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"), + }, + { + epl: endpoints.List{ + API: "https://malicious.com/1", + Authorization: "https://example.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"), + }, + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://example.com/2", + Token: "https://malicious.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"), + }, + { + epl: endpoints.List{ + API: "https://example.com/1", + Authorization: "https://malicious.com/2", + Token: "https://example.com/3", + }, + err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"), + }, + } + + for _, tc := range testCases { + ep := &endpoints.Endpoints{ + API: endpoints.Versions{ + V3: tc.epl, + }, + } + // Update the handler + hs.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + jsonStr, err := json.Marshal(ep) + if err != nil { + t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err) + } + + fmt.Fprintln(w, string(jsonStr)) + })) + b := &base.Base{ + URL: s.URL, + HTTPClient: c, + } + err = Endpoints(context.Background(), b) + if getErrorMsg(err) != getErrorMsg(tc.err) { + t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err) + } + // The error was not nil, continue because endpoints should not be compared + if tc.err != nil { + continue + } + if ep == nil { + t.Fatalf("No test case endpoints") + } + // if no error then the endpoints should be equal + if !compareEndpoints(*ep, b.Endpoints) { + t.Fatalf("Endpoints are not equal, got: %v, want: %v", b.Endpoints, ep) + } + } +} diff --git a/internal/server/api_test.go b/internal/server/api_test.go deleted file mode 100644 index b1e3550..0000000 --- a/internal/server/api_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package server - -import ( - "encoding/json" - "fmt" - "net/http" - "testing" - - "github.com/eduvpn/eduvpn-common/internal/test" - "github.com/go-errors/errors" -) - -func getErrorMsg(err error) string { - if err == nil { - return "" - } - return err.Error() -} - -func compareEndpoints(ep1 Endpoints, ep2 Endpoints) bool { - v3_1 := ep1.API.V3 - v3_2 := ep2.API.V3 - return v3_1.API == v3_2.API && v3_1.Authorization == v3_2.Authorization && v3_1.Token == v3_2.Token -} - -func Test_APIGetEndpoints(t *testing.T) { - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello!") - }) - hs := &test.HandlerSet{} - hs.SetHandler(handler) - s := test.NewServer(hs) - defer s.Close() - - c, err := s.Client() - if err != nil { - t.Fatalf("failed to get client for test server endpoints: %v", err) - } - - testCases := []struct { - epl EndpointList - err error - }{ - { - epl: EndpointList{ - API: "https://example.com/1", - Authorization: "https://example.com/2", - Token: "https://example.com/3", - }, - err: nil, - }, - { - epl: EndpointList{ - API: "http://example.com/1", - Authorization: "https://example.com/2", - Token: "https://example.com/3", - }, - err: errors.New("API scheme: 'http', is not equal to authorization scheme: 'https'"), - }, - { - epl: EndpointList{ - API: "https://example.com/1", - Authorization: "https://example.com/2", - Token: "ftp://example.com/3", - }, - err: errors.New("API scheme: 'https', is not equal to token scheme: 'ftp'"), - }, - { - epl: EndpointList{ - API: "https://malicious.com/1", - Authorization: "https://example.com/2", - Token: "https://example.com/3", - }, - err: errors.New("API host: 'malicious.com', is not equal to authorization host: 'example.com'"), - }, - { - epl: EndpointList{ - API: "https://example.com/1", - Authorization: "https://example.com/2", - Token: "https://malicious.com/3", - }, - err: errors.New("API host: 'example.com', is not equal to token host: 'malicious.com'"), - }, - { - epl: EndpointList{ - API: "https://example.com/1", - Authorization: "https://malicious.com/2", - Token: "https://example.com/3", - }, - err: errors.New("API host: 'example.com', is not equal to authorization host: 'malicious.com'"), - }, - } - - for _, tc := range testCases { - ep := &Endpoints{ - API: EndpointsVersions{ - V3: tc.epl, - }, - } - // Update the handler - hs.SetHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - jsonStr, err := json.Marshal(ep) - if err != nil { - t.Fatalf("failed to marshal JSON for test case: %v, err: %v", tc, err) - } - - fmt.Fprintln(w, string(jsonStr)) - })) - gotEP, err := APIGetEndpoints(s.URL, c) - if getErrorMsg(err) != getErrorMsg(tc.err) { - t.Fatalf("Errors not equal, want err: %v, got: %v", tc.err, err) - } - // The error was not nil, continue because endpoints should not be compared - if tc.err != nil { - continue - } - if ep == nil { - t.Fatalf("No test case endpoints") - } - if gotEP == nil { - t.Fatalf("Got no endpoints for nil error") - } - // if no error then the endpoints should be equal - if !compareEndpoints(*ep, *gotEP) { - t.Fatalf("Endpoints are not equal, got: %v, want: %v", gotEP, ep) - } - } -} diff --git a/internal/server/base.go b/internal/server/base.go deleted file mode 100644 index c7a9adc..0000000 --- a/internal/server/base.go +++ /dev/null @@ -1,112 +0,0 @@ -package server - -import ( - "time" - - "github.com/eduvpn/eduvpn-common/internal/http" -) - -// Base is the base type for servers. -type Base struct { - URL string `json:"base_url"` - DisplayName map[string]string `json:"display_name"` - SupportContact []string `json:"support_contact"` - Endpoints Endpoints `json:"endpoints"` - Profiles ProfileInfo `json:"profiles"` - StartTime time.Time `json:"start_time"` - EndTime time.Time `json:"expire_time"` - Type string `json:"server_type"` - httpClient *http.Client -} - -func (b *Base) InitializeEndpoints() error { - ep, err := APIGetEndpoints(b.URL, b.httpClient) - if err != nil { - return err - } - b.Endpoints = *ep - return nil -} - -func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo { - var valid []Profile - for _, p := range b.Profiles.Info.ProfileList { - // Not a valid profile because it does not support openvpn - // Also the client does not support wireguard - if !p.SupportsOpenVPN() && !wireguardSupport { - continue - } - valid = append(valid, p) - } - return ProfileInfo{ - Current: b.Profiles.Current, - Info: ProfileListInfo{ProfileList: valid}, - } -} - -// RenewButtonTime returns the time when the renew button should be shown for the server -// Implemented according to: https://github.com/eduvpn/documentation/blob/cdf4d054f7652d74e4192494e8bb0e21040e46ac/API.md#session-expiry -func (b *Base) RenewButtonTime() int64 { - d := b.EndTime.Sub(b.StartTime) - - // If the time is less than 24 hours (a day left), show it when 30 minutes have passed or on expired if less than 30 minutes - dayl := time.Duration(24 * time.Hour) - if d < dayl { - // Get the minimum time to add, 30 minutes or on expired - m := time.Duration(30 * time.Minute) - // The total delta time is larger, return that we should show the button after 30 minutes - if d > m { - return b.StartTime.Add(30 * time.Minute).Unix() - } - // Just show it on expired - return b.StartTime.Add(d).Unix() - } - - // Else just show it when 24 hours is left - // This is the delta minus 24 hours left as that's how long it takes for a day to be left in the expiry - // We thus add this to the start time - tillDay := d - dayl - t := b.StartTime.Add(tillDay) - return t.Unix() -} - -func (b *Base) CountdownTime() int64 { - d := b.EndTime.Sub(b.StartTime) - - dayl := time.Duration(24 * time.Hour) - - // This is just the last 24 hours - // if less than or equal to 24 hours, immediately - if d <= dayl { - return b.StartTime.Unix() - } - - tillDay := d - dayl - t := b.StartTime.Add(tillDay) - return t.Unix() -} - -func (b *Base) NotificationTimes() []int64 { - last := []time.Duration{ - time.Duration(0), - time.Duration(1 * time.Hour), - time.Duration(2 * time.Hour), - time.Duration(4 * time.Hour), - } - - var t []int64 - - d := b.EndTime.Sub(b.StartTime) - for _, l := range last { - // If the notification remaining time is more than the total delta, continue - if l > d { - continue - } - // calculating the time till a notification must happen - tillN := d - l - // Get absolute time when this notification must be shown by adding the delta - c := b.StartTime.Add(tillN) - t = append(t, c.Unix()) - } - return t -} diff --git a/internal/server/base/base.go b/internal/server/base/base.go new file mode 100644 index 0000000..d483dad --- /dev/null +++ b/internal/server/base/base.go @@ -0,0 +1,90 @@ +package base + +import ( + "time" + + "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/server/endpoints" + "github.com/eduvpn/eduvpn-common/internal/server/profile" + "github.com/eduvpn/eduvpn-common/types/server" +) + +// Base is the base type for servers. +type Base struct { + URL string `json:"base_url"` + DisplayName map[string]string `json:"display_name"` + SupportContact []string `json:"support_contact"` + Endpoints endpoints.Endpoints `json:"endpoints"` + Profiles profile.Info `json:"profiles"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"expire_time"` + Type server.Type `json:"server_type"` + HTTPClient *http.Client `json:"-"` +} + +// RenewButtonTime returns the time when the renew button should be shown for the server +// Implemented according to: https://github.com/eduvpn/documentation/blob/cdf4d054f7652d74e4192494e8bb0e21040e46ac/API.md#session-expiry +func (b *Base) RenewButtonTime() int64 { + d := b.EndTime.Sub(b.StartTime) + + // If the time is less than 24 hours (a day left), show it when 30 minutes have passed or on expired if less than 30 minutes + dayl := time.Duration(24 * time.Hour) + if d < dayl { + // Get the minimum time to add, 30 minutes or on expired + m := time.Duration(30 * time.Minute) + // The total delta time is larger, return that we should show the button after 30 minutes + if d > m { + return b.StartTime.Add(30 * time.Minute).Unix() + } + // Just show it on expired + return b.StartTime.Add(d).Unix() + } + + // Else just show it when 24 hours is left + // This is the delta minus 24 hours left as that's how long it takes for a day to be left in the expiry + // We thus add this to the start time + tillDay := d - dayl + t := b.StartTime.Add(tillDay) + return t.Unix() +} + +func (b *Base) CountdownTime() int64 { + d := b.EndTime.Sub(b.StartTime) + + dayl := time.Duration(24 * time.Hour) + + // This is just the last 24 hours + // if less than or equal to 24 hours, immediately + if d <= dayl { + return b.StartTime.Unix() + } + + tillDay := d - dayl + t := b.StartTime.Add(tillDay) + return t.Unix() +} + +func (b *Base) NotificationTimes() []int64 { + last := []time.Duration{ + time.Duration(0), + time.Duration(1 * time.Hour), + time.Duration(2 * time.Hour), + time.Duration(4 * time.Hour), + } + + var t []int64 + + d := b.EndTime.Sub(b.StartTime) + for _, l := range last { + // If the notification remaining time is more than the total delta, continue + if l > d { + continue + } + // calculating the time till a notification must happen + tillN := d - l + // Get absolute time when this notification must be shown by adding the delta + c := b.StartTime.Add(tillN) + t = append(t, c.Unix()) + } + return t +} diff --git a/internal/server/custom.go b/internal/server/custom.go deleted file mode 100644 index 6171e24..0000000 --- a/internal/server/custom.go +++ /dev/null @@ -1,35 +0,0 @@ -package server - -import ( - "github.com/go-errors/errors" -) - -func (ss *Servers) SetCustomServer(server Server) error { - b, err := server.Base() - if err != nil { - return err - } - - if b.Type != "custom_server" { - return errors.New("not a custom server") - } - - if _, ok := ss.CustomServers.Map[b.URL]; ok { - ss.CustomServers.CurrentURL = b.URL - ss.IsType = CustomServerType - } else { - return errors.Errorf("this server is not yet added as a custom server: %s", b.URL) - } - return nil -} - -func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { - if srv, ok := ss.CustomServers.Map[url]; ok { - return srv, nil - } - return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url) -} - -func (ss *Servers) RemoveCustomServer(url string) { - ss.CustomServers.Remove(url) -} diff --git a/internal/server/custom/custom.go b/internal/server/custom/custom.go new file mode 100644 index 0000000..14a72a5 --- /dev/null +++ b/internal/server/custom/custom.go @@ -0,0 +1,31 @@ +package custom + +import ( + "context" + + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/institute" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type ( + Server = institute.Server + Servers = institute.Servers +) + +func New(ctx context.Context, url string) (*Server, error) { + b := base.Base{ + URL: url, + DisplayName: map[string]string{"en": url}, + Type: server.TypeCustom, + } + if err := api.Endpoints(ctx, &b); err != nil { + return nil, err + } + API := b.Endpoints.API.V3 + + s := &Server{Basic: b} + s.Auth.Init(url, API.Authorization, API.Token) + return s, nil +} diff --git a/internal/server/endpoints/endpoints.go b/internal/server/endpoints/endpoints.go new file mode 100644 index 0000000..75bca55 --- /dev/null +++ b/internal/server/endpoints/endpoints.go @@ -0,0 +1,53 @@ +package endpoints + +import ( + "net/url" + + "github.com/go-errors/errors" +) + +type List struct { + API string `json:"api_endpoint"` + Authorization string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` +} + +type Versions struct { + V2 List `json:"http://eduvpn.org/api#2"` + V3 List `json:"http://eduvpn.org/api#3"` +} + +// Endpoints defines the json format for /.well-known/vpn-user-portal". +type Endpoints struct { + API Versions `json:"api"` + V string `json:"v"` +} + +func (e Endpoints) Validate() error { + v3 := e.API.V3 + pAPI, err := url.Parse(v3.API) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API endpoint", 0) + } + pAuth, err := url.Parse(v3.Authorization) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API authorization endpoint", 0) + } + pToken, err := url.Parse(v3.Token) + if err != nil { + return errors.WrapPrefix(err, "failed to parse API token endpoint", 0) + } + if pAPI.Scheme != pAuth.Scheme { + return errors.Errorf("API scheme: '%v', is not equal to authorization scheme: '%v'", pAPI.Scheme, pAuth.Scheme) + } + if pAPI.Scheme != pToken.Scheme { + return errors.Errorf("API scheme: '%v', is not equal to token scheme: '%v'", pAPI.Scheme, pToken.Scheme) + } + if pAPI.Host != pAuth.Host { + return errors.Errorf("API host: '%v', is not equal to authorization host: '%v'", pAPI.Host, pAuth.Host) + } + if pAPI.Host != pToken.Host { + return errors.Errorf("API host: '%v', is not equal to token host: '%v'", pAPI.Host, pToken.Host) + } + return nil +} diff --git a/internal/server/institute/institute.go b/internal/server/institute/institute.go new file mode 100644 index 0000000..ada1977 --- /dev/null +++ b/internal/server/institute/institute.go @@ -0,0 +1,106 @@ +package institute + +import ( + "context" + + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/go-errors/errors" +) + +type Server struct { + // An instute access server has its own OAuth + Auth oauth.OAuth `json:"oauth"` + + // Embed the server base + Basic base.Base `json:"base"` +} + +type Servers struct { + Map map[string]*Server `json:"map"` + CurrentURL string `json:"current_url"` +} + +func New( + ctx context.Context, + url string, + name map[string]string, + supportContact []string, +) (*Server, error) { + b := base.Base{ + URL: url, + DisplayName: name, + SupportContact: supportContact, + Type: server.TypeInstituteAccess, + } + if err := api.Endpoints(ctx, &b); err != nil { + return nil, err + } + API := b.Endpoints.API.V3 + + s := &Server{Basic: b} + s.Auth.Init(url, API.Authorization, API.Token) + return s, nil +} + +func (s *Servers) Current() (*Server, error) { + if s.Map == nil { + return nil, errors.Errorf("No map is found when getting the current server") + } + + srv, ok := s.Map[s.CurrentURL] + if !ok || srv == nil { + return nil, errors.Errorf("server not found") + } + return srv, nil +} + +func (s *Servers) Remove(url string) error { + // check if it is in the map to begin with + if _, ok := s.Map[url]; ok { + delete(s.Map, url) + } else { + return errors.Errorf("cannot remove URL: %v, not found in list", url) + } + + // Reset the current url + if s.CurrentURL == url { + s.CurrentURL = "" + } + return nil +} + +func (s *Servers) Add(srv *Server) { + if s.Map == nil { + s.Map = make(map[string]*Server) + } + s.Map[srv.Basic.URL] = srv +} + +func (s *Server) TemplateAuth() func(string) string { + return func(authURL string) string { + return authURL + } +} + +func (s *Server) Base() (*base.Base, error) { + return &s.Basic, nil +} + +func (s *Server) OAuth() *oauth.OAuth { + return &s.Auth +} + +func (s *Server) NeedsLocation() bool { + return false +} + +func (s *Server) Public() (interface{}, error) { + return &server.Server{ + DisplayName: s.Basic.DisplayName, + Identifier: s.Basic.URL, + Profiles: s.Basic.Profiles.Public(), + }, nil +} diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go deleted file mode 100644 index ebafb26..0000000 --- a/internal/server/instituteaccess.go +++ /dev/null @@ -1,114 +0,0 @@ -package server - -import ( - "github.com/eduvpn/eduvpn-common/internal/discovery" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/go-errors/errors" -) - -type InstituteAccessServer struct { - // An instute access server has its own OAuth - Auth oauth.OAuth `json:"oauth"` - - // Embed the server base - Basic Base `json:"base"` -} - -type InstituteAccessServers struct { - Map map[string]*InstituteAccessServer `json:"map"` - CurrentURL string `json:"current_url"` -} - -func (ss *Servers) SetInstituteAccess(srv Server) error { - b, err := srv.Base() - if err != nil { - return err - } - - if b.Type != "institute_access" { - return errors.Errorf("not an institute access server, URL: %s, type: %s", b.URL, b.Type) - } - - if _, ok := ss.InstituteServers.Map[b.URL]; ok { - ss.InstituteServers.CurrentURL = b.URL - ss.IsType = InstituteAccessServerType - } else { - return errors.Errorf("institute access server with URL: %s, is not yet configured", b.URL) - } - return nil -} - -func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { - if srv, ok := ss.InstituteServers.Map[url]; ok { - return srv, nil - } - return nil, errors.Errorf("no institute access server with URL: %s", url) -} - -func (ss *Servers) RemoveInstituteAccess(url string) { - ss.InstituteServers.Remove(url) -} - -func (iass *InstituteAccessServers) Remove(url string) { - // Reset the current url - if iass.CurrentURL == url { - iass.CurrentURL = "" - } - - // Delete the url from the map - delete(iass.Map, url) -} - -func (ias *InstituteAccessServer) TemplateAuth() func(string) string { - return func(authURL string) string { - return authURL - } -} - -func (ias *InstituteAccessServer) Base() (*Base, error) { - return &ias.Basic, nil -} - -func (ias *InstituteAccessServer) OAuth() *oauth.OAuth { - return &ias.Auth -} - -func (ias *InstituteAccessServer) RefreshEndpoints(_ *discovery.Discovery) error { - // Re-initialize the endpoints - b, err := ias.Base() - if err != nil { - return err - } - - err = b.InitializeEndpoints() - if err != nil { - return err - } - - // update OAuth - auth := ias.OAuth() - if auth != nil { - auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization - auth.TokenURL = b.Endpoints.API.V3.Token - } - return nil -} - -func (ias *InstituteAccessServer) init( - url string, - name map[string]string, - srvType string, - supportContact []string, -) error { - ias.Basic.URL = url - ias.Basic.DisplayName = name - ias.Basic.SupportContact = supportContact - ias.Basic.Type = srvType - err := ias.Basic.InitializeEndpoints() - if err != nil { - return err - } - API := ias.Basic.Endpoints.API.V3 - ias.Auth.Init(url, API.Authorization, API.Token) - return nil -} diff --git a/internal/server/list.go b/internal/server/list.go new file mode 100644 index 0000000..2660102 --- /dev/null +++ b/internal/server/list.go @@ -0,0 +1,179 @@ +package server + +import ( + "context" + + "github.com/eduvpn/eduvpn-common/internal/server/custom" + "github.com/eduvpn/eduvpn-common/internal/server/institute" + "github.com/eduvpn/eduvpn-common/internal/server/secure" + discotypes "github.com/eduvpn/eduvpn-common/types/discovery" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" + "github.com/go-errors/errors" +) + +type List struct { + CustomServers custom.Servers `json:"custom_servers"` + InstituteServers institute.Servers `json:"institute_servers"` + SecureInternetHomeServer secure.Server `json:"secure_internet_home"` + IsType srvtypes.Type `json:"is_secure_internet"` +} + +// HasSecureInternet returns whether or not we have a secure internet server added +func (l *List) HasSecureInternet() bool { + return len(l.SecureInternetHomeServer.BaseMap) > 0 +} + +func (l *List) HasSecureLocation() bool { + return l.SecureInternetHomeServer.CurrentLocation != "" +} + +func (l *List) Current() (Server, error) { + if l.IsType == srvtypes.TypeUnknown { + return nil, errors.New("no current server") + } + if l.IsType == srvtypes.TypeSecureInternet { + if !l.HasSecureLocation() { + return nil, errors.Errorf("Current server is secure internet but there is no secure internet location: %v", l.IsType) + } + return &l.SecureInternetHomeServer, nil + } + + if l.IsType == srvtypes.TypeCustom { + return l.CustomServers.Current() + } + return l.InstituteServers.Current() +} + +func (l *List) AddCustom(ctx context.Context, url string) (Server, error) { + srv, err := custom.New(ctx, url) + if err != nil { + return nil, err + } + l.CustomServers.Add(srv) + return srv, nil +} + +func (l *List) AddInstituteAccess(ctx context.Context, discoServer *discotypes.Server) (Server, error) { + srv, err := institute.New(ctx, discoServer.BaseURL, discoServer.DisplayName, discoServer.SupportContact) + if err != nil { + return nil, err + } + l.InstituteServers.Add(srv) + return srv, nil +} + +func (l *List) AddSecureInternet( + ctx context.Context, + secureOrg *discotypes.Organization, + secureServer *discotypes.Server, +) (*secure.Server, error) { + // If we have specified an organization ID + // We also need to get an authorization template + err := l.SecureInternetHomeServer.Init(ctx, secureOrg, secureServer) + if err != nil { + return nil, err + } + + l.IsType = srvtypes.TypeSecureInternet + return &l.SecureInternetHomeServer, nil +} + +func (l *List) SecureInternet(identifier string) (*secure.Server, error) { + if l.SecureInternetHomeServer.HomeOrganizationID != identifier { + return nil, errors.Errorf("no secure internet home server with identifier: %s", identifier) + } + return &l.SecureInternetHomeServer, nil +} + +func (l *List) SetSecureInternet(server Server) error { + b, err := server.Base() + if err != nil { + return err + } + + if b.Type != srvtypes.TypeSecureInternet { + return errors.New("not a secure internet server") + } + + // The location should already be configured + // TODO: check for location? + l.IsType = srvtypes.TypeSecureInternet + return nil +} + +func (l *List) RemoveSecureInternet(identifier string) error { + oid := l.SecureInternetHomeServer.HomeOrganizationID + if identifier != oid { + return errors.Errorf("cannot remove secure internet server: identifier: %s, is not equal to the Org ID: %s", identifier, oid) + } + // Empty out the struct + l.SecureInternetHomeServer = secure.Server{} + + // If the current server is secure internet, reset to unknown + if l.IsType == srvtypes.TypeSecureInternet { + l.IsType = srvtypes.TypeUnknown + } + return nil +} + +func (l *List) SetInstituteAccess(srv Server) error { + b, err := srv.Base() + if err != nil { + return err + } + + if b.Type != srvtypes.TypeInstituteAccess { + return errors.Errorf("not an institute access server, URL: %s, type: %v", b.URL, b.Type) + } + + if _, ok := l.InstituteServers.Map[b.URL]; ok { + l.InstituteServers.CurrentURL = b.URL + l.IsType = srvtypes.TypeInstituteAccess + } else { + return errors.Errorf("institute access server with URL: %s, is not yet configured", b.URL) + } + return nil +} + +func (l *List) InstituteAccess(url string) (*institute.Server, error) { + if srv, ok := l.InstituteServers.Map[url]; ok { + return srv, nil + } + return nil, errors.Errorf("no institute access server with URL: %s", url) +} + +func (l *List) RemoveInstituteAccess(url string) error { + // TODO: Reset current to unknown? + return l.InstituteServers.Remove(url) +} + +func (l *List) SetCustom(server Server) error { + b, err := server.Base() + if err != nil { + return err + } + + if b.Type != srvtypes.TypeCustom { + return errors.New("not a custom server") + } + + if _, ok := l.CustomServers.Map[b.URL]; ok { + l.CustomServers.CurrentURL = b.URL + l.IsType = srvtypes.TypeCustom + } else { + return errors.Errorf("this server is not yet added as a custom server: %s", b.URL) + } + return nil +} + +func (l *List) CustomServer(url string) (*institute.Server, error) { + if srv, ok := l.CustomServers.Map[url]; ok { + return srv, nil + } + return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url) +} + +func (l *List) RemoveCustom(url string) error { + // TODO: Reset current to unknown? + return l.CustomServers.Remove(url) +} diff --git a/internal/server/profile.go b/internal/server/profile.go deleted file mode 100644 index d981421..0000000 --- a/internal/server/profile.go +++ /dev/null @@ -1,44 +0,0 @@ -package server - -type Profile struct { - ID string `json:"profile_id"` - DisplayName string `json:"display_name"` - VPNProtoList []string `json:"vpn_proto_list"` - DefaultGateway bool `json:"default_gateway"` -} - -type ProfileListInfo struct { - ProfileList []Profile `json:"profile_list"` -} - -type ProfileInfo struct { - Current string `json:"current_profile"` - Info ProfileListInfo `json:"info"` -} - -func (info ProfileInfo) CurrentProfileIndex() int { - for i, profile := range info.Info.ProfileList { - if profile.ID == info.Current { - return i - } - } - // Default is 'first' profile - return 0 -} - -func (profile *Profile) supportsProtocol(protocol string) bool { - for _, proto := range profile.VPNProtoList { - if proto == protocol { - return true - } - } - return false -} - -func (profile *Profile) SupportsWireguard() bool { - return profile.supportsProtocol("wireguard") -} - -func (profile *Profile) SupportsOpenVPN() bool { - return profile.supportsProtocol("openvpn") -} diff --git a/internal/server/profile/profile.go b/internal/server/profile/profile.go new file mode 100644 index 0000000..7a19685 --- /dev/null +++ b/internal/server/profile/profile.go @@ -0,0 +1,88 @@ +package profile + +import ( + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" +) + +type Profile struct { + ID string `json:"profile_id"` + DisplayName string `json:"display_name"` + VPNProtoList []string `json:"vpn_proto_list"` + DefaultGateway bool `json:"default_gateway"` +} + +type ListInfo struct { + ProfileList []Profile `json:"profile_list"` +} + +type Info struct { + Current string `json:"current_profile"` + Info ListInfo `json:"info"` +} + +func (info Info) CurrentProfileIndex() int { + for i, profile := range info.Info.ProfileList { + if profile.ID == info.Current { + return i + } + } + // Default is 'first' profile + return 0 +} + +func (profile *Profile) supportsProtocol(protocol string) bool { + for _, proto := range profile.VPNProtoList { + if proto == protocol { + return true + } + } + return false +} + +func (profile *Profile) SupportsWireguard() bool { + return profile.supportsProtocol("wireguard") +} + +func (profile *Profile) SupportsOpenVPN() bool { + return profile.supportsProtocol("openvpn") +} + +func (info Info) Supported(wireguardSupport bool) []Profile { + var valid []Profile + for _, p := range info.Info.ProfileList { + // Not a valid profile because it does not support openvpn + // Also the client does not support wireguard + if !p.SupportsOpenVPN() && !wireguardSupport { + continue + } + valid = append(valid, p) + } + return valid +} + +func (info Info) Has(id string) bool { + for _, p := range info.Info.ProfileList { + if p.ID == id { + return true + } + } + return false +} + +func (info Info) Public() server.Profiles { + m := make(map[string]server.Profile) + for _, p := range info.Info.ProfileList { + var protocols []protocol.Protocol + for _, ps := range p.VPNProtoList { + protocols = append(protocols, protocol.New(ps)) + } + m[p.ID] = server.Profile{ + DisplayName: map[string]string{ + "en": p.DisplayName, + }, + Protocols: protocols, + } + } + return server.Profiles{Map: m, Current: info.Current} +} diff --git a/internal/server/profile/profile_test.go b/internal/server/profile/profile_test.go new file mode 100644 index 0000000..e246b5c --- /dev/null +++ b/internal/server/profile/profile_test.go @@ -0,0 +1,100 @@ +package profile + +import "testing" + +func Test_CurrentProfileIndex(t *testing.T) { + testCases := []struct { + profiles []Profile + current string + index int + }{ + { + profiles: []Profile{ + { + ID: "a", + DisplayName: "b", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + }, + current: "a", + index: 0, + }, + { + profiles: []Profile{ + { + ID: "a", + DisplayName: "a", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + { + ID: "b", + DisplayName: "b", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + }, + current: "b", + index: 1, + }, + { + profiles: []Profile{ + { + ID: "a", + DisplayName: "a", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + { + ID: "b", + DisplayName: "b", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + }, + current: "", + index: 0, + }, + { + profiles: []Profile{ + { + ID: "a", + DisplayName: "a", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + { + ID: "b", + DisplayName: "b", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + }, + current: "", + index: 0, + }, + { + profiles: []Profile{ + { + ID: "a", + DisplayName: "a", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + { + ID: "b", + DisplayName: "b", + VPNProtoList: []string{"openvpn", "wireguard"}, + }, + }, + current: "idonotexist", + index: 0, + }, + } + + for _, tc := range testCases { + pri := &Info{ + Current: tc.current, + Info: ListInfo{ + ProfileList: tc.profiles, + }, + } + got := pri.CurrentProfileIndex() + if got != tc.index { + t.Fatalf("failed getting profile index, got: '%v', want: '%v'", got, tc.index) + } + } +} diff --git a/internal/server/profile_test.go b/internal/server/profile_test.go deleted file mode 100644 index d6a7e9d..0000000 --- a/internal/server/profile_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package server - -import "testing" - -func Test_CurrentProfileIndex(t *testing.T) { - testCases := []struct { - profiles []Profile - current string - index int - }{ - { - profiles: []Profile{ - { - ID: "a", - DisplayName: "b", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - }, - current: "a", - index: 0, - }, - { - profiles: []Profile{ - { - ID: "a", - DisplayName: "a", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - { - ID: "b", - DisplayName: "b", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - }, - current: "b", - index: 1, - }, - { - profiles: []Profile{ - { - ID: "a", - DisplayName: "a", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - { - ID: "b", - DisplayName: "b", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - }, - current: "", - index: 0, - }, - { - profiles: []Profile{ - { - ID: "a", - DisplayName: "a", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - { - ID: "b", - DisplayName: "b", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - }, - current: "", - index: 0, - }, - { - profiles: []Profile{ - { - ID: "a", - DisplayName: "a", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - { - ID: "b", - DisplayName: "b", - VPNProtoList: []string{"openvpn", "wireguard"}, - }, - }, - current: "idonotexist", - index: 0, - }, - } - - for _, tc := range testCases { - pri := &ProfileInfo{ - Current: tc.current, - Info: ProfileListInfo{ - ProfileList: tc.profiles, - }, - } - got := pri.CurrentProfileIndex() - if got != tc.index { - t.Fatalf("failed getting profile index, got: '%v', want: '%v'", got, tc.index) - } - } -} diff --git a/internal/server/secure/secure.go b/internal/server/secure/secure.go new file mode 100644 index 0000000..6fed010 --- /dev/null +++ b/internal/server/secure/secure.go @@ -0,0 +1,148 @@ +package secure + +import ( + "context" + "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/util" + discotypes "github.com/eduvpn/eduvpn-common/types/discovery" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/go-errors/errors" +) + +// Server secure internet server which has its own OAuth tokens +// It specifies the current location url it is connected to. +type Server struct { + Auth oauth.OAuth `json:"oauth"` + DisplayName map[string]string `json:"display_name"` + + // The home server has a list of info for each configured server location + BaseMap map[string]*base.Base `json:"base_map"` + + // We have the authorization URL template, the home organization ID and the current location + AuthorizationTemplate string `json:"authorization_template"` + HomeOrganizationID string `json:"home_organization_id"` + CurrentLocation string `json:"current_location"` +} + +func (s *Server) TemplateAuth() func(string) string { + return func(authURL string) string { + return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID) + } +} + +func (s *Server) Base() (*base.Base, error) { + if s.BaseMap == nil { + return nil, errors.Errorf("secure internet map not found") + } + + b, ok := s.BaseMap[s.CurrentLocation] + if !ok { + return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation) + } + return b, nil +} + +func (s *Server) OAuth() *oauth.OAuth { + return &s.Auth +} + +func (s *Server) NeedsLocation() bool { + if s.CurrentLocation == "" { + return true + } + if len(s.BaseMap) == 0 { + return true + } + return false +} + +func (s *Server) addLocation(ctx context.Context, locSrv *discotypes.Server) (*base.Base, error) { + // Initialize the base map if it is non-nil + if s.BaseMap == nil { + s.BaseMap = make(map[string]*base.Base) + } + + // Add the location to the base map + b, ok := s.BaseMap[locSrv.CountryCode] + if !ok || b == nil { + // Create the base to be added to the map + b = &base.Base{} + b.URL = locSrv.BaseURL + b.DisplayName = s.DisplayName + b.SupportContact = locSrv.SupportContact + b.Type = server.TypeSecureInternet + if err := api.Endpoints(ctx, b); err != nil { + return nil, err + } + } + + // Ensure it is in the map + s.BaseMap[locSrv.CountryCode] = b + return b, nil +} + +func (s *Server) Location(ctx context.Context, locSrv *discotypes.Server) error { + if _, err := s.addLocation(ctx, locSrv); err != nil { + return err + } + s.CurrentLocation = locSrv.CountryCode + return nil +} + +// Initializes the home server and adds its own location. +func (s *Server) Init( + ctx context.Context, + homeOrg *discotypes.Organization, homeLoc *discotypes.Server, +) error { + if s.HomeOrganizationID != homeOrg.OrgID { + // New home organisation, clear everything + *s = Server{} + } + + // Make sure to set the organization ID + s.HomeOrganizationID = homeOrg.OrgID + s.DisplayName = homeOrg.DisplayName + + // Make sure to set the authorization URL template + s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate + + b, err := s.addLocation(ctx, homeLoc) + if err != nil { + return err + } + + // set the home location as the current + err = s.Location(ctx, homeLoc) + if err != nil { + return err + } + + // Set the current location to the home location if there is none + if s.CurrentLocation == "" { + s.CurrentLocation = homeLoc.CountryCode + } + + // Make sure oauth contains our endpoints + s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) + return nil +} + +func (s *Server) Public() (interface{}, error) { + b, err := s.Base() + var p server.Profiles + dn := s.DisplayName + if err == nil { + dn = b.DisplayName + p = b.Profiles.Public() + } + return &server.SecureInternet{ + Server: server.Server{ + DisplayName: dn, + Identifier: s.HomeOrganizationID, + Profiles: p, + }, + CountryCode: s.CurrentLocation, + }, nil +} diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go deleted file mode 100644 index 4b42303..0000000 --- a/internal/server/secureinternet.go +++ /dev/null @@ -1,175 +0,0 @@ -package server - -import ( - "github.com/eduvpn/eduvpn-common/internal/discovery" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/internal/util" - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/go-errors/errors" -) - -// SecureInternetHomeServer secure internet server which has its own OAuth tokens -// It specifies the current location url it is connected to. -type SecureInternetHomeServer struct { - Auth oauth.OAuth `json:"oauth"` - DisplayName map[string]string `json:"display_name"` - - // The home server has a list of info for each configured server location - BaseMap map[string]*Base `json:"base_map"` - - // We have the authorization URL template, the home organization ID and the current location - AuthorizationTemplate string `json:"authorization_template"` - HomeOrganizationID string `json:"home_organization_id"` - CurrentLocation string `json:"current_location"` -} - -func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { - if !ss.HasSecureLocation() { - return nil, errors.Errorf("no secure internet home server") - } - return &ss.SecureInternetHomeServer, nil -} - -func (ss *Servers) SetSecureInternet(server Server) error { - b, err := server.Base() - if err != nil { - return err - } - - if b.Type != "secure_internet" { - return errors.Errorf("not a secure internet server") - } - - // The location should already be configured - // TODO: check for location? - ss.IsType = SecureInternetServerType - return nil -} - -func (ss *Servers) RemoveSecureInternet() { - // Empty out the struct - ss.SecureInternetHomeServer = SecureInternetHomeServer{} - - // If the current server is secure internet, default to custom server - if ss.IsType == SecureInternetServerType { - ss.IsType = CustomServerType - } -} - -func (s *SecureInternetHomeServer) TemplateAuth() func(string) string { - return func(authURL string) string { - return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID) - } -} - -func (s *SecureInternetHomeServer) Base() (*Base, error) { - if s.BaseMap == nil { - return nil, errors.Errorf("secure internet map not found") - } - - b, ok := s.BaseMap[s.CurrentLocation] - if !ok { - return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation) - } - return b, nil -} - -func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth { - return &s.Auth -} - -func (ss *Servers) HasSecureLocation() bool { - return ss.SecureInternetHomeServer.CurrentLocation != "" -} - -func (s *SecureInternetHomeServer) addLocation(locSrv *discotypes.Server) (*Base, error) { - // Initialize the base map if it is non-nil - if s.BaseMap == nil { - s.BaseMap = make(map[string]*Base) - } - - // Add the location to the base map - b, ok := s.BaseMap[locSrv.CountryCode] - if !ok || b == nil { - // Create the base to be added to the map - b = &Base{} - b.URL = locSrv.BaseURL - b.DisplayName = s.DisplayName - b.SupportContact = locSrv.SupportContact - b.Type = "secure_internet" - if err := b.InitializeEndpoints(); err != nil { - return nil, err - } - } - - // Ensure it is in the map - s.BaseMap[locSrv.CountryCode] = b - return b, nil -} - -// Initializes the home server and adds its own location. -func (s *SecureInternetHomeServer) init( - homeOrg *discotypes.Organization, homeLoc *discotypes.Server, -) error { - if s.HomeOrganizationID != homeOrg.OrgID { - // New home organisation, clear everything - *s = SecureInternetHomeServer{} - } - - // Make sure to set the organization ID - s.HomeOrganizationID = homeOrg.OrgID - s.DisplayName = homeOrg.DisplayName - - // Make sure to set the authorization URL template - s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate - - b, err := s.addLocation(homeLoc) - if err != nil { - return err - } - - // Set the current location to the home location if there is none - if s.CurrentLocation == "" { - s.CurrentLocation = homeLoc.CountryCode - } - - // Make sure oauth contains our endpoints - s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) - return nil -} - -func (s *SecureInternetHomeServer) RefreshEndpoints(disco *discovery.Discovery) error { - // update OAuth for home server - auth := s.OAuth() - if auth != nil && s.HomeOrganizationID != "" { - _, srv, err := disco.SecureHomeArgs(s.HomeOrganizationID) - if err != nil { - return err - } - if hb, ok := s.BaseMap[srv.CountryCode]; ok && hb != nil { - err := hb.InitializeEndpoints() - if err != nil { - return err - } - auth.BaseAuthorizationURL = hb.Endpoints.API.V3.Authorization - auth.TokenURL = hb.Endpoints.API.V3.Token - } - // already updated, return - if srv.CountryCode == s.CurrentLocation { - return nil - } - } - - // refresh the current location endpoints - // Re-initialize the endpoints - b, err := s.Base() - if err != nil { - return err - } - - err = b.InitializeEndpoints() - if err != nil { - return err - } - return nil -} diff --git a/internal/server/server.go b/internal/server/server.go index 4bd8766..e7229c5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,23 +1,21 @@ package server import ( + "context" "os" "time" "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" + "github.com/eduvpn/eduvpn-common/internal/server/api" + "github.com/eduvpn/eduvpn-common/internal/server/base" + "github.com/eduvpn/eduvpn-common/internal/server/profile" "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/types/protocol" + srvtypes "github.com/eduvpn/eduvpn-common/types/server" "github.com/go-errors/errors" ) -type Type int8 - -const ( - CustomServerType Type = iota - InstituteAccessServerType - SecureInternetServerType -) - type Server interface { OAuth() *oauth.OAuth @@ -25,27 +23,13 @@ type Server interface { TemplateAuth() func(string) string // Base returns the server base - Base() (*Base, error) + Base() (*base.Base, error) - // RefreshEndpoints - RefreshEndpoints(*discovery.Discovery) error -} + // NeedsLocation checks if the server needs a secure internet location + NeedsLocation() bool -type EndpointList struct { - API string `json:"api_endpoint"` - Authorization string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` -} - -type EndpointsVersions struct { - V2 EndpointList `json:"http://eduvpn.org/api#2"` - V3 EndpointList `json:"http://eduvpn.org/api#3"` -} - -// Endpoints defines the json format for /.well-known/vpn-user-portal". -type Endpoints struct { - API EndpointsVersions `json:"api"` - V string `json:"v"` + // Public returns the representation that will be passed over the CGO barrier + Public() (interface{}, error) } func UpdateTokens(srv Server, t oauth.Token) { @@ -56,12 +40,12 @@ func OAuthURL(srv Server, name string) (string, error) { return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } -func OAuthExchange(srv Server) error { - return srv.OAuth().Exchange() +func OAuthExchange(ctx context.Context, srv Server) error { + return srv.OAuth().Exchange(ctx) } -func HeaderToken(srv Server) (string, error) { - return srv.OAuth().AccessToken() +func HeaderToken(ctx context.Context, srv Server) (string, error) { + return srv.OAuth().AccessToken(ctx) } func MarkTokenExpired(srv Server) { @@ -72,16 +56,13 @@ func MarkTokensForRenew(srv Server) { srv.OAuth().SetTokenRenew() } -func NeedsRelogin(srv Server) bool { - _, err := HeaderToken(srv) +func NeedsRelogin(ctx context.Context, srv Server) bool { + // TODO: this error can be a context cancel + _, err := HeaderToken(ctx, srv) return err != nil } -func CancelOAuth(srv Server) { - srv.OAuth().Cancel() -} - -func CurrentProfile(srv Server) (*Profile, error) { +func CurrentProfile(srv Server) (*profile.Profile, error) { b, err := srv.Base() if err != nil { return nil, err @@ -96,19 +77,31 @@ func CurrentProfile(srv Server) (*Profile, error) { return nil, errors.Errorf("profile not found: " + pID) } -func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { +func ValidProfiles(srv Server, wireguardSupport bool) (*[]profile.Profile, error) { // No error wrapping here otherwise we wrap it too much b, err := srv.Base() if err != nil { return nil, err } - ps := b.ValidProfiles(wireguardSupport) - if len(ps.Info.ProfileList) == 0 { + ps := b.Profiles.Supported(wireguardSupport) + if len(ps) == 0 { return nil, errors.Errorf("no profiles found with supported protocols") } return &ps, nil } +func Profile(srv Server, id string) error { + b, err := srv.Base() + if err != nil { + return err + } + if !b.Profiles.Has(id) { + return errors.Errorf("no profile available with id: %s", id) + } + b.Profiles.Current = id + return nil +} + type ConfigData struct { // The configuration Config string @@ -120,7 +113,18 @@ type ConfigData struct { Tokens oauth.Token } -func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { +// Public gets the public data from the types package +// dg specifies if this config is default gateway +func (c *ConfigData) Public(dg bool) srvtypes.Configuration { + return srvtypes.Configuration{ + VPNConfig: c.Config, + Protocol: protocol.New(c.Type), + DefaultGateway: dg, + Tokens: c.Tokens.Public(), + } +} + +func wireguardGetConfig(ctx context.Context, srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { return nil, err @@ -133,7 +137,7 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi } pub := key.PublicKey().String() - cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport) + cfg, proto, exp, err := api.ConnectWireguard(ctx, b, srv.OAuth(), pID, pub, preferTCP, openVPNSupport) if err != nil { return nil, err } @@ -159,13 +163,13 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*Confi return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil } -func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { +func openVPNGetConfig(ctx context.Context, srv Server, preferTCP bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { return nil, err } pid := b.Profiles.Current - cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) + cfg, exp, err := api.ConnectOpenVPN(ctx, b, srv.OAuth(), pid, preferTCP) if err != nil { return nil, err } @@ -184,15 +188,14 @@ func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil } -func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { - // Get new profiles using the info call - // This does not override the current profile - err := APIInfo(srv) +func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bool, error) { + b, err := srv.Base() if err != nil { return false, err } - - b, err := srv.Base() + // Get new profiles using the info call + // This does not override the current profile + err = api.Info(ctx, b, srv.OAuth()) if err != nil { return false, err } @@ -225,7 +228,18 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { return true, nil } -func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { +func RefreshEndpoints(ctx context.Context, srv Server) error { + // Re-initialize the endpoints + // TODO: Make this a warning instead? + b, err := srv.Base() + if err != nil { + return err + } + + return api.Endpoints(ctx, b) +} + +func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { return nil, err @@ -250,10 +264,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, case wg: // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - cfg, err = wireguardGetConfig(server, preferTCP, ovpn) + cfg, err = wireguardGetConfig(ctx, server, preferTCP, ovpn) // The config only supports OpenVPN case ovpn: - cfg, err = openVPNGetConfig(server, preferTCP) + cfg, err = openVPNGetConfig(ctx, server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: return nil, errors.New("no supported protocol found") @@ -267,6 +281,10 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, return cfg, err } -func Disconnect(server Server) error { - return APIDisconnect(server) +func Disconnect(ctx context.Context, server Server) error { + b, err := server.Base() + if err != nil { + return err + } + return api.Disconnect(ctx, b, server.OAuth()) } diff --git a/internal/server/servers.go b/internal/server/servers.go deleted file mode 100644 index 60c993d..0000000 --- a/internal/server/servers.go +++ /dev/null @@ -1,121 +0,0 @@ -package server - -import ( - discotypes "github.com/eduvpn/eduvpn-common/types/discovery" - "github.com/go-errors/errors" -) - -// TODO: Have a dedicated type for custom servers -type Servers struct { - // A custom server is just an institute access server under the hood - CustomServers InstituteAccessServers `json:"custom_servers"` - InstituteServers InstituteAccessServers `json:"institute_servers"` - SecureInternetHomeServer SecureInternetHomeServer `json:"secure_internet_home"` - IsType Type `json:"is_secure_internet"` -} - -// HasSecureInternet returns whether or not we have a secure internet server added -func (ss *Servers) HasSecureInternet() bool { - return len(ss.SecureInternetHomeServer.BaseMap) > 0 -} - -func (ss *Servers) AddSecureInternet( - secureOrg *discotypes.Organization, - secureServer *discotypes.Server, -) (Server, error) { - // If we have specified an organization ID - // We also need to get an authorization template - err := ss.SecureInternetHomeServer.init(secureOrg, secureServer) - if err != nil { - return nil, err - } - - ss.IsType = SecureInternetServerType - return &ss.SecureInternetHomeServer, nil -} - -func (ss *Servers) GetCurrentServer() (Server, error) { - // TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server) - if ss.IsType == SecureInternetServerType { - if !ss.HasSecureLocation() { - return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType) - } - return &ss.SecureInternetHomeServer, nil - } - - srvs := &ss.InstituteServers - - if ss.IsType == CustomServerType { - srvs = &ss.CustomServers - } - if srvs.Map == nil { - return nil, errors.Errorf("srvs.Map is nil") - } - - srv, ok := srvs.Map[srvs.CurrentURL] - if !ok || srv == nil { - return nil, errors.Errorf("server not found") - } - return srv, nil -} - -func (ss *Servers) addInstituteAndCustom( - discoServer *discotypes.Server, - isCustom bool, -) (Server, error) { - URL := discoServer.BaseURL - srvs := &ss.InstituteServers - srvType := InstituteAccessServerType - - if isCustom { - srvs = &ss.CustomServers - srvType = CustomServerType - } - - if srvs.Map == nil { - srvs.Map = make(map[string]*InstituteAccessServer) - } - - srv, ok := srvs.Map[URL] - - // initialize the server if it doesn't exist yet - if !ok { - srv = &InstituteAccessServer{} - } - - if err := srv.init(URL, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil { - return nil, err - } - srvs.Map[URL] = srv - ss.IsType = srvType - return srv, nil -} - -func (ss *Servers) AddInstituteAccessServer( - instituteServer *discotypes.Server, -) (Server, error) { - return ss.addInstituteAndCustom(instituteServer, false) -} - -func (ss *Servers) AddCustomServer( - customServer *discotypes.Server, -) (Server, error) { - return ss.addInstituteAndCustom(customServer, true) -} - -func (ss *Servers) GetSecureLocation() string { - return ss.SecureInternetHomeServer.CurrentLocation -} - -func (ss *Servers) SetSecureLocation( - chosenLocationServer *discotypes.Server, -) error { - // Make sure to add the current location - - if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil { - return err - } - - ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode - return nil -} -- cgit v1.2.3