summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/api.go214
-rw-r--r--internal/server/base.go29
-rw-r--r--internal/server/custom.go41
-rw-r--r--internal/server/instituteaccess.go85
-rw-r--r--internal/server/secureinternet.go151
-rw-r--r--internal/server/server.go318
-rw-r--r--internal/server/servers.go117
7 files changed, 380 insertions, 575 deletions
diff --git a/internal/server/api.go b/internal/server/api.go
index 21ba6f4..dfa8e14 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -2,7 +2,6 @@ package server
import (
"encoding/json"
- "errors"
"fmt"
"net/http"
"net/url"
@@ -10,130 +9,114 @@ import (
"time"
httpw "github.com/eduvpn/eduvpn-common/internal/http"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
func APIGetEndpoints(baseURL string) (*Endpoints, error) {
- errorMessage := "failed getting server endpoints"
- url, urlErr := url.Parse(baseURL)
- if urlErr != nil {
- return nil, types.NewWrappedError(errorMessage, urlErr)
+ u, err := url.Parse(baseURL)
+ if err != nil {
+ return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- wellKnownPath := "/.well-known/vpn-user-portal"
+ wk := "/.well-known/vpn-user-portal"
- url.Path = path.Join(url.Path, wellKnownPath)
- _, body, bodyErr := httpw.Get(url.String())
-
- if bodyErr != nil {
- return nil, types.NewWrappedError(errorMessage, bodyErr)
+ u.Path = path.Join(u.Path, wk)
+ _, body, err := httpw.Get(u.String())
+ if err != nil {
+ return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- endpoints := &Endpoints{}
- jsonErr := json.Unmarshal(body, endpoints)
-
- if jsonErr != nil {
- return nil, types.NewWrappedError(errorMessage, jsonErr)
+ ep := &Endpoints{}
+ if err = json.Unmarshal(body, ep); err != nil {
+ return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0)
}
- return endpoints, nil
+ return ep, nil
}
func apiAuthorized(
- server Server,
+ srv Server,
method string,
endpoint string,
opts *httpw.OptionalParams,
) (http.Header, []byte, error) {
- errorMessage := "failed API authorized"
// Ensure optional is not nil as we will fill it with headers
if opts == nil {
opts = &httpw.OptionalParams{}
}
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0)
}
// Join the paths
- url, urlErr := url.Parse(base.Endpoints.API.V3.API)
- if urlErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, urlErr)
+ u, err := url.Parse(b.Endpoints.API.V3.API)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0)
}
- url.Path = path.Join(url.Path, endpoint)
+ u.Path = path.Join(u.Path, endpoint)
// Make sure the tokens are valid, this will return an error if re-login is needed
- token, tokenErr := HeaderToken(server)
- if tokenErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, tokenErr)
+ t, err := HeaderToken(srv)
+ if err != nil {
+ return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0)
}
- headerKey := "Authorization"
- headerValue := fmt.Sprintf("Bearer %s", token)
+ key := "Authorization"
+ val := fmt.Sprintf("Bearer %s", t)
if opts.Headers != nil {
- opts.Headers.Add(headerKey, headerValue)
+ opts.Headers.Add(key, val)
} else {
- opts.Headers = http.Header{headerKey: {headerValue}}
+ opts.Headers = http.Header{key: {val}}
}
- return httpw.MethodWithOpts(method, url.String(), opts)
+ return httpw.MethodWithOpts(method, u.String(), opts)
}
func apiAuthorizedRetry(
- server Server,
+ srv Server,
method string,
endpoint string,
opts *httpw.OptionalParams,
) (http.Header, []byte, error) {
- errorMessage := "failed authorized API retry"
- header, body, bodyErr := apiAuthorized(server, method, endpoint, opts)
-
- if bodyErr != nil {
- var error *httpw.StatusError
-
- // Only retry authorized if we get a HTTP 401
- if errors.As(bodyErr, &error) && error.Status == 401 {
- // Mark the token as expired and retry so we trigger the refresh flow
- MarkTokenExpired(server)
- retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
- if retryErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, retryErr)
- }
- return retryHeader, retryBody, nil
- }
- return nil, nil, types.NewWrappedError(errorMessage, bodyErr)
- }
- return header, body, nil
-}
-
-func APIInfo(server Server) error {
- errorMessage := "failed API /info"
- _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil)
- if bodyErr != nil {
- return types.NewWrappedError(errorMessage, bodyErr)
+ h, body, err := apiAuthorized(srv, method, endpoint, opts)
+ if err == nil {
+ return h, body, nil
}
- structure := ProfileInfo{}
- jsonErr := json.Unmarshal(body, &structure)
- if jsonErr != nil {
- return types.NewWrappedError(errorMessage, jsonErr)
+ statErr := &httpw.StatusError{}
+ // Only retry authorized if we get an HTTP 401
+ if errors.As(err, &statErr) && statErr.Status == 401 {
+ // Mark the token as expired and retry, so we trigger the refresh flow
+ MarkTokenExpired(srv)
+ h, body, err = apiAuthorized(srv, method, endpoint, opts)
}
+ return h, body, err
+}
- base, baseErr := server.Base()
+func APIInfo(srv Server) error {
+ _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil)
+ if err != nil {
+ return err
+ }
+ pi := ProfileInfo{}
+ if err = json.Unmarshal(body, &pi); err != nil {
+ return errors.WrapPrefix(err, "failed API /info", 0)
+ }
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
// Store the profiles and make sure that the current profile is not overwritten
- previousProfile := base.Profiles.Current
- base.Profiles = structure
- base.Profiles.Current = previousProfile
+ prev := b.Profiles.Current
+ b.Profiles = pi
+ b.Profiles.Current = prev
return nil
}
// see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1
-func GetPreferTCPString(preferTCP bool) string {
+func boolToYesNo(preferTCP bool) string {
if preferTCP {
return "yes"
}
@@ -141,88 +124,77 @@ func GetPreferTCPString(preferTCP bool) string {
}
func APIConnectWireguard(
- server Server,
+ srv Server,
profileID string,
pubkey string,
preferTCP bool,
- supportsOpenVPN bool,
+ openVPNSupport bool,
) (string, string, time.Time, error) {
- errorMessage := "failed obtaining a WireGuard configuration"
- headers := http.Header{
+ hdrs := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-wireguard-profile"},
}
// This profile also supports OpenVPN
// Indicate that we also accept OpenVPN profiles
- if supportsOpenVPN {
- headers.Add("accept", "application/x-openvpn-profile")
+ if openVPNSupport {
+ hdrs.Add("accept", "application/x-openvpn-profile")
}
- urlForm := url.Values{
+ vals := url.Values{
"profile_id": {profileID},
"public_key": {pubkey},
- "prefer_tcp": {GetPreferTCPString(preferTCP)},
+ "prefer_tcp": {boolToYesNo(preferTCP)},
}
- header, connectBody, connectErr := apiAuthorizedRetry(
- server,
- http.MethodPost,
- "/connect",
- &httpw.OptionalParams{Headers: headers, Body: urlForm},
- )
- if connectErr != nil {
- return "", "", time.Time{}, types.NewWrappedError(
- errorMessage,
- connectErr,
- )
+ h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect",
+ &httpw.OptionalParams{Headers: hdrs, Body: vals})
+ if err != nil {
+ return "", "", time.Time{}, err
}
- expires := header.Get("expires")
+ exp := h.Get("expires")
- pTime, pTimeErr := http.ParseTime(expires)
- if pTimeErr != nil {
- return "", "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr)
+ ptm, err := http.ParseTime(exp)
+ if err != nil {
+ return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0)
}
- contentType := header.Get("content-type")
-
- content := "openvpn"
- if contentType == "application/x-wireguard-profile" {
- content = "wireguard"
+ ct := h.Get("content-type")
+ c := "openvpn"
+ if ct == "application/x-wireguard-profile" {
+ c = "wireguard"
}
- return string(connectBody), content, pTime, nil
+
+ return string(body), c, ptm, nil
}
-func APIConnectOpenVPN(server Server, profileID string, preferTCP bool) (string, time.Time, error) {
- errorMessage := "failed obtaining an OpenVPN configuration"
- headers := http.Header{
+func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) {
+ hdrs := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-openvpn-profile"},
}
- urlForm := url.Values{
+ vals := url.Values{
"profile_id": {profileID},
- "prefer_tcp": {GetPreferTCPString(preferTCP)},
+ "prefer_tcp": {boolToYesNo(preferTCP)},
}
- header, connectBody, connectErr := apiAuthorizedRetry(
- server,
- http.MethodPost,
- "/connect",
- &httpw.OptionalParams{Headers: headers, Body: urlForm},
- )
- if connectErr != nil {
- return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr)
+ h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect",
+ &httpw.OptionalParams{Headers: hdrs, Body: vals})
+ if err != nil {
+ return "", time.Time{}, err
}
- expires := header.Get("expires")
- pTime, pTimeErr := http.ParseTime(expires)
- if pTimeErr != nil {
- return "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr)
+ exp := h.Get("expires")
+ ptm, err := http.ParseTime(exp)
+ if err != nil {
+ return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0)
}
- return string(connectBody), pTime, nil
+
+ return string(body), ptm, nil
}
+// APIDisconnect disconnects from the API.
// This needs no further return value as it's best effort.
func APIDisconnect(server Server) {
_, _, _ = apiAuthorized(server, http.MethodPost, "/disconnect", nil)
diff --git a/internal/server/base.go b/internal/server/base.go
index bb88eb3..81049cf 100644
--- a/internal/server/base.go
+++ b/internal/server/base.go
@@ -2,11 +2,9 @@ package server
import (
"time"
-
- "github.com/eduvpn/eduvpn-common/types"
)
-// The base type for servers.
+// Base is the base type for servers.
type Base struct {
URL string `json:"base_url"`
DisplayName map[string]string `json:"display_name"`
@@ -18,28 +16,27 @@ type Base struct {
Type string `json:"server_type"`
}
-func (base *Base) InitializeEndpoints() error {
- errorMessage := "failed initializing endpoints"
- endpoints, endpointsErr := APIGetEndpoints(base.URL)
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+func (b *Base) InitializeEndpoints() error {
+ ep, err := APIGetEndpoints(b.URL)
+ if err != nil {
+ return err
}
- base.Endpoints = *endpoints
+ b.Endpoints = *ep
return nil
}
-func (base *Base) ValidProfiles(clientSupportsWireguard bool) ProfileInfo {
- var validProfiles []Profile
- for _, profile := range base.Profiles.Info.ProfileList {
+func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo {
+ var vps []Profile
+ for _, p := range b.Profiles.Info.ProfileList {
// Not a valid profile because it does not support openvpn
// Also the client does not support wireguard
- if !profile.supportsOpenVPN() && !clientSupportsWireguard {
+ if !p.supportsOpenVPN() && !wireguardSupport {
continue
}
- validProfiles = append(validProfiles, profile)
+ vps = append(vps, p)
}
return ProfileInfo{
- Current: base.Profiles.Current,
- Info: ProfileListInfo{ProfileList: validProfiles},
+ Current: b.Profiles.Current,
+ Info: ProfileListInfo{ProfileList: vps},
}
}
diff --git a/internal/server/custom.go b/internal/server/custom.go
index d376727..bf0b230 100644
--- a/internal/server/custom.go
+++ b/internal/server/custom.go
@@ -1,42 +1,35 @@
package server
import (
- "errors"
- "fmt"
-
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-func (servers *Servers) SetCustomServer(server Server) error {
- errorMessage := "failed setting custom server"
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+func (ss *Servers) SetCustomServer(server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
}
- if base.Type != "custom_server" {
- return types.NewWrappedError(errorMessage, errors.New("not a custom server"))
+ if b.Type != "custom_server" {
+ return errors.WrapPrefix(err, "not a custom server", 0)
}
- if _, ok := servers.CustomServers.Map[base.URL]; ok {
- servers.CustomServers.CurrentURL = base.URL
- servers.IsType = CustomServerType
+ if _, ok := ss.CustomServers.Map[b.URL]; ok {
+ ss.CustomServers.CurrentURL = b.URL
+ ss.IsType = CustomServerType
} else {
- return types.NewWrappedError(errorMessage, errors.New("not a custom server"))
+ return errors.Errorf("not a custom server")
}
return nil
}
-func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) {
- if server, ok := servers.CustomServers.Map[url]; ok {
- return server, nil
+func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) {
+ if srv, ok := ss.CustomServers.Map[url]; ok {
+ return srv, nil
}
- return nil, types.NewWrappedError(
- "failed to get institute access server",
- fmt.Errorf("no custom server with URL: %s", url),
- )
+ return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url)
}
-func (servers *Servers) RemoveCustomServer(url string) {
- servers.CustomServers.Remove(url)
+func (ss *Servers) RemoveCustomServer(url string) {
+ ss.CustomServers.Remove(url)
}
diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go
index 9b6f735..56ed1cf 100644
--- a/internal/server/instituteaccess.go
+++ b/internal/server/instituteaccess.go
@@ -1,14 +1,10 @@
package server
import (
- "errors"
- "fmt"
-
"github.com/eduvpn/eduvpn-common/internal/oauth"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-// An instute access server.
type InstituteAccessServer struct {
// An instute access server has its own OAuth
Auth oauth.OAuth `json:"oauth"`
@@ -22,80 +18,75 @@ type InstituteAccessServers struct {
CurrentURL string `json:"current_url"`
}
-func (servers *Servers) SetInstituteAccess(server Server) error {
- errorMessage := "failed setting institute access server"
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+func (ss *Servers) SetInstituteAccess(srv Server) error {
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
- if base.Type != "institute_access" {
- return types.NewWrappedError(errorMessage, errors.New("not an institute access server"))
+ if b.Type != "institute_access" {
+ return errors.Errorf("not an institute access server")
}
- if _, ok := servers.InstituteServers.Map[base.URL]; ok {
- servers.InstituteServers.CurrentURL = base.URL
- servers.IsType = InstituteAccessServerType
+ if _, ok := ss.InstituteServers.Map[b.URL]; ok {
+ ss.InstituteServers.CurrentURL = b.URL
+ ss.IsType = InstituteAccessServerType
} else {
- return types.NewWrappedError(errorMessage, errors.New("no such institute access server"))
+ return errors.Errorf("no such institute access server")
}
return nil
}
-func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) {
- if server, ok := servers.InstituteServers.Map[url]; ok {
- return server, nil
+func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) {
+ if srv, ok := ss.InstituteServers.Map[url]; ok {
+ return srv, nil
}
- return nil, types.NewWrappedError(
- "failed to get institute access server",
- fmt.Errorf("no institute access server with URL: %s", url),
- )
+ return nil, errors.Errorf("no institute access server with URL: %s", url)
}
-func (servers *Servers) RemoveInstituteAccess(url string) {
- servers.InstituteServers.Remove(url)
+func (ss *Servers) RemoveInstituteAccess(url string) {
+ ss.InstituteServers.Remove(url)
}
-func (servers *InstituteAccessServers) Remove(url string) {
+func (iass *InstituteAccessServers) Remove(url string) {
// Reset the current url
- if servers.CurrentURL == url {
- servers.CurrentURL = ""
+ if iass.CurrentURL == url {
+ iass.CurrentURL = ""
}
// Delete the url from the map
- delete(servers.Map, url)
+ delete(iass.Map, url)
}
-func (institute *InstituteAccessServer) TemplateAuth() func(string) string {
+func (ias *InstituteAccessServer) TemplateAuth() func(string) string {
return func(authURL string) string {
return authURL
}
}
-func (institute *InstituteAccessServer) Base() (*Base, error) {
- return &institute.Basic, nil
+func (ias *InstituteAccessServer) Base() (*Base, error) {
+ return &ias.Basic, nil
}
-func (institute *InstituteAccessServer) OAuth() *oauth.OAuth {
- return &institute.Auth
+func (ias *InstituteAccessServer) OAuth() *oauth.OAuth {
+ return &ias.Auth
}
-func (institute *InstituteAccessServer) init(
+func (ias *InstituteAccessServer) init(
url string,
- displayName map[string]string,
- serverType string,
+ name map[string]string,
+ srvType string,
supportContact []string,
) error {
- errorMessage := fmt.Sprintf("failed initializing server %s", url)
- institute.Basic.URL = url
- institute.Basic.DisplayName = displayName
- institute.Basic.SupportContact = supportContact
- institute.Basic.Type = serverType
- endpointsErr := institute.Basic.InitializeEndpoints()
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+ ias.Basic.URL = url
+ ias.Basic.DisplayName = name
+ ias.Basic.SupportContact = supportContact
+ ias.Basic.Type = srvType
+ err := ias.Basic.InitializeEndpoints()
+ if err != nil {
+ return err
}
- API := institute.Basic.Endpoints.API.V3
- institute.Auth.Init(url, API.Authorization, API.Token)
+ API := ias.Basic.Endpoints.API.V3
+ ias.Auth.Init(url, API.Authorization, API.Token)
return nil
}
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index 998390d..12263a6 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -1,15 +1,13 @@
package server
import (
- "errors"
- "fmt"
-
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/util"
"github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
-// A secure internet server which has its own OAuth tokens
+// SecureInternetHomeServer secure internet server which has its own OAuth tokens
// It specifies the current location url it is connected to.
type SecureInternetHomeServer struct {
Auth oauth.OAuth `json:"oauth"`
@@ -24,150 +22,111 @@ type SecureInternetHomeServer struct {
CurrentLocation string `json:"current_location"`
}
-func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) {
- if !servers.HasSecureLocation() {
- return nil, errors.New("no secure internet home server")
+func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) {
+ if !ss.HasSecureLocation() {
+ return nil, errors.Errorf("no secure internet home server")
}
- return &servers.SecureInternetHomeServer, nil
+ return &ss.SecureInternetHomeServer, nil
}
-func (servers *Servers) SetSecureInternet(server Server) error {
- errorMessage := "failed setting secure internet server"
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+func (ss *Servers) SetSecureInternet(server Server) error {
+ b, err := server.Base()
+ if err != nil {
+ return err
}
- if base.Type != "secure_internet" {
- return types.NewWrappedError(errorMessage, errors.New("not a secure internet server"))
+ if b.Type != "secure_internet" {
+ return errors.Errorf("not a secure internet server")
}
// The location should already be configured
// TODO: check for location?
- servers.IsType = SecureInternetServerType
+ ss.IsType = SecureInternetServerType
return nil
}
-func (servers *Servers) RemoveSecureInternet() {
+func (ss *Servers) RemoveSecureInternet() {
// Empty out the struct
- servers.SecureInternetHomeServer = SecureInternetHomeServer{}
+ ss.SecureInternetHomeServer = SecureInternetHomeServer{}
// If the current server is secure internet, default to custom server
- if servers.IsType == SecureInternetServerType {
- servers.IsType = CustomServerType
+ if ss.IsType == SecureInternetServerType {
+ ss.IsType = CustomServerType
}
}
-func (server *SecureInternetHomeServer) TemplateAuth() func(string) string {
+func (s *SecureInternetHomeServer) TemplateAuth() func(string) string {
return func(authURL string) string {
- return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID)
+ return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID)
}
}
-func (server *SecureInternetHomeServer) Base() (*Base, error) {
- errorMessage := "failed getting current secure internet home base"
- if server.BaseMap == nil {
- return nil, types.NewWrappedError(
- errorMessage,
- &SecureInternetMapNotFoundError{},
- )
+func (s *SecureInternetHomeServer) Base() (*Base, error) {
+ if s.BaseMap == nil {
+ return nil, errors.Errorf("secure internet map not found")
}
- base, exists := server.BaseMap[server.CurrentLocation]
-
- if !exists {
- return nil, types.NewWrappedError(
- errorMessage,
- &SecureInternetBaseNotFoundError{Current: server.CurrentLocation},
- )
+ b, ok := s.BaseMap[s.CurrentLocation]
+ if !ok {
+ return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation)
}
- return base, nil
+ return b, nil
}
-func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth {
- return &server.Auth
+func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth {
+ return &s.Auth
}
-func (servers *Servers) HasSecureLocation() bool {
- return servers.SecureInternetHomeServer.CurrentLocation != ""
+func (ss *Servers) HasSecureLocation() bool {
+ return ss.SecureInternetHomeServer.CurrentLocation != ""
}
-func (server *SecureInternetHomeServer) addLocation(
- locationServer *types.DiscoveryServer,
-) (*Base, error) {
- errorMessage := "failed adding a location"
+func (s *SecureInternetHomeServer) addLocation(locSrv *types.DiscoveryServer) (*Base, error) {
// Initialize the base map if it is non-nil
- if server.BaseMap == nil {
- server.BaseMap = make(map[string]*Base)
+ if s.BaseMap == nil {
+ s.BaseMap = make(map[string]*Base)
}
// Add the location to the base map
- base, exists := server.BaseMap[locationServer.CountryCode]
-
- if !exists || base == nil {
+ b, ok := s.BaseMap[locSrv.CountryCode]
+ if !ok || b == nil {
// Create the base to be added to the map
- base = &Base{}
- base.URL = locationServer.BaseURL
- base.DisplayName = server.DisplayName
- base.SupportContact = locationServer.SupportContact
- base.Type = "secure_internet"
- endpointsErr := base.InitializeEndpoints()
- if endpointsErr != nil {
- return nil, types.NewWrappedError(errorMessage, endpointsErr)
+ b = &Base{}
+ b.URL = locSrv.BaseURL
+ b.DisplayName = s.DisplayName
+ b.SupportContact = locSrv.SupportContact
+ b.Type = "secure_internet"
+ if err := b.InitializeEndpoints(); err != nil {
+ return nil, err
}
}
// Ensure it is in the map
- server.BaseMap[locationServer.CountryCode] = base
- return base, nil
+ s.BaseMap[locSrv.CountryCode] = b
+ return b, nil
}
// Initializes the home server and adds its own location.
-func (server *SecureInternetHomeServer) init(
- homeOrg *types.DiscoveryOrganization,
- homeLocation *types.DiscoveryServer,
-) error {
- errorMessage := "failed initializing secure internet home server"
-
- if server.HomeOrganizationID != homeOrg.OrgID {
+func (s *SecureInternetHomeServer) init(
+ homeOrg *types.DiscoveryOrganization, homeLoc *types.DiscoveryServer) error {
+ if s.HomeOrganizationID != homeOrg.OrgID {
// New home organisation, clear everything
- *server = SecureInternetHomeServer{}
+ *s = SecureInternetHomeServer{}
}
// Make sure to set the organization ID
- server.HomeOrganizationID = homeOrg.OrgID
- server.DisplayName = homeOrg.DisplayName
+ s.HomeOrganizationID = homeOrg.OrgID
+ s.DisplayName = homeOrg.DisplayName
// Make sure to set the authorization URL template
- server.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate
+ s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate
- base, baseErr := server.addLocation(homeLocation)
-
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
+ b, err := s.addLocation(homeLoc)
+ if err != nil {
+ return err
}
// Make sure oauth contains our endpoints
- server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token)
+ s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token)
return nil
}
-
-type SecureInternetHomeNotFoundError struct{}
-
-func (e *SecureInternetHomeNotFoundError) Error() string {
- return "failed to get secure internet home server, not found"
-}
-
-type SecureInternetMapNotFoundError struct{}
-
-func (e *SecureInternetMapNotFoundError) Error() string {
- return "secure internet map not found"
-}
-
-type SecureInternetBaseNotFoundError struct {
- Current string
-}
-
-func (e *SecureInternetBaseNotFoundError) Error() string {
- return fmt.Sprintf("secure internet base not found with current location: %s", e.Current)
-}
diff --git a/internal/server/server.go b/internal/server/server.go
index 95244d5..de0fa9a 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,13 +1,11 @@
package server
import (
- "errors"
- "fmt"
"time"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/wireguard"
- "github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
type Type int8
@@ -21,10 +19,10 @@ const (
type Server interface {
OAuth() *oauth.OAuth
- // Get the authorization URL template function
+ // TemplateAuth returns the authorization URL template function
TemplateAuth() func(string) string
- // Gets the server base
+ // Base returns the server base
Base() (*Base, error)
}
@@ -34,7 +32,7 @@ type EndpointList struct {
Token string `json:"token_endpoint"`
}
-// Struct that defines the json format for /.well-known/vpn-user-portal".
+// Endpoints defines the json format for /.well-known/vpn-user-portal".
type Endpoints struct {
API struct {
V2 EndpointList `json:"http://eduvpn.org/api#2"`
@@ -43,310 +41,226 @@ type Endpoints struct {
V string `json:"v"`
}
-func ShouldRenewButton(server Server) bool {
- base, baseErr := server.Base()
-
- if baseErr != nil {
+func ShouldRenewButton(srv Server) bool {
+ b, err := srv.Base()
+ if err != nil {
// FIXME: Log error here?
return false
}
// Get current time
- current := time.Now()
+ now := time.Now()
// Session is expired
- if !current.Before(base.EndTime) {
+ if !now.Before(b.EndTime) {
return true
}
// 30 minutes have not passed
- if !current.After(base.StartTime.Add(30 * time.Minute)) {
+ if !now.After(b.StartTime.Add(30 * time.Minute)) {
return false
}
// Session will not expire today
- if !current.Add(24 * time.Hour).After(base.EndTime) {
+ if !now.Add(24 * time.Hour).After(b.EndTime) {
return false
}
// Session duration is less than 24 hours but not 75% has passed
- duration := base.EndTime.Sub(base.StartTime)
- percentTime := base.StartTime.Add((duration / 4) * 3)
- if duration < time.Duration(24*time.Hour) && !current.After(percentTime) {
+ d := b.EndTime.Sub(b.StartTime)
+ pct := b.StartTime.Add((d / 4) * 3)
+ if d < 24*time.Hour && !now.After(pct) {
return false
}
return true
}
-func OAuthURL(server Server, name string) (string, error) {
- return server.OAuth().AuthURL(name, server.TemplateAuth())
+func OAuthURL(srv Server, name string) (string, error) {
+ return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
-func OAuthExchange(server Server) error {
- return server.OAuth().Exchange()
+func OAuthExchange(srv Server) error {
+ return srv.OAuth().Exchange()
}
-func HeaderToken(server Server) (string, error) {
- token, tokenErr := server.OAuth().AccessToken()
- if tokenErr != nil {
- return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr)
- }
- return token, nil
+func HeaderToken(srv Server) (string, error) {
+ return srv.OAuth().AccessToken()
}
-func MarkTokenExpired(server Server) {
- server.OAuth().SetTokenExpired()
+func MarkTokenExpired(srv Server) {
+ srv.OAuth().SetTokenExpired()
}
-func MarkTokensForRenew(server Server) {
- server.OAuth().SetTokenRenew()
+func MarkTokensForRenew(srv Server) {
+ srv.OAuth().SetTokenRenew()
}
-func NeedsRelogin(server Server) bool {
- _, tokenErr := HeaderToken(server)
- return tokenErr != nil
+func NeedsRelogin(srv Server) bool {
+ _, err := HeaderToken(srv)
+ return err != nil
}
-func CancelOAuth(server Server) {
- server.OAuth().Cancel()
+func CancelOAuth(srv Server) {
+ srv.OAuth().Cancel()
}
-func CurrentProfile(server Server) (*Profile, error) {
- errorMessage := "failed getting current profile"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return nil, types.NewWrappedError(errorMessage, baseErr)
+func CurrentProfile(srv Server) (*Profile, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
}
- profileID := base.Profiles.Current
- for _, profile := range base.Profiles.Info.ProfileList {
- if profile.ID == profileID {
+ pid := b.Profiles.Current
+ for _, profile := range b.Profiles.Info.ProfileList {
+ if profile.ID == pid {
return &profile, nil
}
}
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentProfileNotFoundError{ProfileID: profileID},
- )
+ return nil, errors.Errorf("profile not found: " + pid)
}
-func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) {
- errorMessage := "failed to get valid profiles"
+func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
// No error wrapping here otherwise we wrap it too much
- base, baseErr := server.Base()
- if baseErr != nil {
- return nil, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return nil, err
}
- profiles := base.ValidProfiles(clientSupportsWireguard)
- if len(profiles.Info.ProfileList) == 0 {
- return nil, types.NewWrappedError(
- errorMessage,
- errors.New("no profiles found with supported protocols"),
- )
+ ps := b.ValidProfiles(wireguardSupport)
+ if len(ps.Info.ProfileList) == 0 {
+ return nil, errors.Errorf("no profiles found with supported protocols")
}
- return &profiles, nil
+ return &ps, nil
}
-func wireguardGetConfig(
- server Server,
- preferTCP bool,
- supportsOpenVPN bool,
-) (string, string, error) {
- errorMessage := "failed getting server WireGuard configuration"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return "", "", types.NewWrappedError(errorMessage, baseErr)
+func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return "", "", err
}
- profileID := base.Profiles.Current
- wireguardKey, wireguardErr := wireguard.GenerateKey()
-
- if wireguardErr != nil {
- return "", "", types.NewWrappedError(errorMessage, wireguardErr)
+ pid := b.Profiles.Current
+ key, err := wireguard.GenerateKey()
+ if err != nil {
+ return "", "", err
}
- wireguardPublicKey := wireguardKey.PublicKey().String()
- config, content, expires, configErr := APIConnectWireguard(
- server,
- profileID,
- wireguardPublicKey,
- preferTCP,
- supportsOpenVPN,
- )
-
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
+ pub := key.PublicKey().String()
+ cfg, ct, exp, err := APIConnectWireguard(srv, pid, pub, preferTCP, openVPNSupport)
+ if err != nil {
+ return "", "", err
}
// Store start and end time
- base.StartTime = time.Now()
- base.EndTime = expires
+ b.StartTime = time.Now()
+ b.EndTime = exp
- if content == "wireguard" {
+ if ct == "wireguard" {
// This needs the go code a way to identify a connection
// Use the uuid of the connection e.g. on Linux
// This needs the client code to call the go code
- config = wireguard.ConfigAddKey(config, wireguardKey)
+ cfg = wireguard.ConfigAddKey(cfg, key)
}
- return config, content, nil
+ return cfg, ct, nil
}
-func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) {
- errorMessage := "failed getting server OpenVPN configuration"
- base, baseErr := server.Base()
-
- if baseErr != nil {
- return "", "", types.NewWrappedError(errorMessage, baseErr)
+func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) {
+ b, err := srv.Base()
+ if err != nil {
+ return "", "", err
}
- profileID := base.Profiles.Current
- configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profileID, preferTCP)
+ pid := b.Profiles.Current
+ cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
// Store start and end time
- base.StartTime = time.Now()
- base.EndTime = expires
+ b.StartTime = time.Now()
+ b.EndTime = exp
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
+ if err != nil {
+ return "", "", err
}
- return configOpenVPN, "openvpn", nil
+ return cfg, "openvpn", nil
}
-func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) {
- errorMessage := "failed has valid profile check"
-
+func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
// Get new profiles using the info call
// This does not override the current profile
- infoErr := APIInfo(server)
- if infoErr != nil {
- return false, types.NewWrappedError(errorMessage, infoErr)
+ err := APIInfo(srv)
+ if err != nil {
+ return false, err
}
- base, baseErr := server.Base()
- if baseErr != nil {
- return false, types.NewWrappedError(errorMessage, baseErr)
+ b, err := srv.Base()
+ if err != nil {
+ return false, err
}
// If there was a profile chosen and it doesn't exist anymore, reset it
- if base.Profiles.Current != "" {
- _, existsProfileErr := CurrentProfile(server)
- if existsProfileErr != nil {
- base.Profiles.Current = ""
+ if b.Profiles.Current != "" {
+ if _, err = CurrentProfile(srv); err != nil {
+ b.Profiles.Current = ""
}
}
- // Set the current profile if there is only one profile or profile is already selected
- if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" {
- // Set the first profile if none is selected
- if base.Profiles.Current == "" {
- base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
- }
- profile, profileErr := CurrentProfile(server)
- // shouldn't happen
- if profileErr != nil {
- return false, types.NewWrappedError(errorMessage, profileErr)
- }
- // Profile does not support OpenVPN but the client also doesn't support WireGuard
- if !profile.supportsOpenVPN() && !clientSupportsWireguard {
- return false, nil
- }
- return true, nil
+ if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" {
+ return false, nil
}
- return false, nil
+ // Set the current profile if there is only one profile or profile is already selected
+ // Set the first profile if none is selected
+ if b.Profiles.Current == "" {
+ b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID
+ }
+ p, err := CurrentProfile(srv)
+ // shouldn't happen
+ if err != nil {
+ return false, err
+ }
+ // Profile does not support OpenVPN but the client also doesn't support WireGuard
+ if !p.supportsOpenVPN() && !wireguardSupport {
+ return false, nil
+ }
+ return true, nil
}
-func RefreshEndpoints(server Server) error {
- errorMessage := "failed to refresh server endpoints"
-
+func RefreshEndpoints(srv Server) error {
// Re-initialize the endpoints
// TODO: Make this a warning instead?
- base, baseErr := server.Base()
- if baseErr != nil {
- return types.NewWrappedError(errorMessage, baseErr)
- }
-
- endpointsErr := base.InitializeEndpoints()
- if endpointsErr != nil {
- return types.NewWrappedError(errorMessage, endpointsErr)
+ b, err := srv.Base()
+ if err != nil {
+ return err
}
- return nil
+ return b.InitializeEndpoints()
}
-func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) {
- errorMessage := "failed getting an OpenVPN/WireGuard configuration"
-
- profile, profileErr := CurrentProfile(server)
- if profileErr != nil {
- return "", "", types.NewWrappedError(errorMessage, profileErr)
+func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) {
+ p, err := CurrentProfile(server)
+ if err != nil {
+ return "", "", err
}
- supportsOpenVPN := profile.supportsOpenVPN()
- supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard
-
- var config string
- var configType string
- var configErr error
+ ovpn := p.supportsOpenVPN()
+ wg := p.supportsWireguard() && wireguardSupport
switch {
// The config supports wireguard and optionally openvpn
- case supportsWireguard:
+ case wg:
// A wireguard connect call needs to generate a wireguard key and add it to the config
// Also the server could send back an OpenVPN config if it supports OpenVPN
- config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN)
+ return wireguardGetConfig(server, preferTCP, ovpn)
// The config only supports OpenVPN
- case supportsOpenVPN:
- config, configType, configErr = openVPNGetConfig(server, preferTCP)
+ case ovpn:
+ return openVPNGetConfig(server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
- return "", "", types.NewWrappedError(errorMessage, errors.New("no supported protocol found"))
+ return "", "", errors.Errorf("no supported protocol found")
}
-
- if configErr != nil {
- return "", "", types.NewWrappedError(errorMessage, configErr)
- }
-
- return config, configType, nil
}
func Disconnect(server Server) {
APIDisconnect(server)
}
-
-type CurrentProfileNotFoundError struct {
- ProfileID string
-}
-
-func (e *CurrentProfileNotFoundError) Error() string {
- return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID)
-}
-
-type ConfigPreferTCPError struct{}
-
-func (e *ConfigPreferTCPError) Error() string {
- return "failed to get config, prefer TCP is on but the server does not support OpenVPN"
-}
-
-type EmptyURLError struct{}
-
-func (e *EmptyURLError) Error() string {
- return "failed ensuring server, empty url provided"
-}
-
-type CurrentNoMapError struct{}
-
-func (e *CurrentNoMapError) Error() string {
- return "failed getting current server, no servers available"
-}
-
-type CurrentNotFoundError struct{}
-
-func (e *CurrentNotFoundError) Error() string {
- return "failed getting current server, not found"
-}
diff --git a/internal/server/servers.go b/internal/server/servers.go
index a076770..b34dcff 100644
--- a/internal/server/servers.go
+++ b/internal/server/servers.go
@@ -1,9 +1,8 @@
package server
import (
- "fmt"
-
"github.com/eduvpn/eduvpn-common/types"
+ "github.com/go-errors/errors"
)
type Servers struct {
@@ -14,125 +13,105 @@ type Servers struct {
IsType Type `json:"is_secure_internet"`
}
-func (servers *Servers) AddSecureInternet(
+func (ss *Servers) AddSecureInternet(
secureOrg *types.DiscoveryOrganization,
secureServer *types.DiscoveryServer,
) (Server, error) {
- errorMessage := "failed adding secure internet server"
// If we have specified an organization ID
// We also need to get an authorization template
- initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer)
+ err := ss.SecureInternetHomeServer.init(secureOrg, secureServer)
- if initErr != nil {
- return nil, types.NewWrappedError(errorMessage, initErr)
+ if err != nil {
+ return nil, err
}
- servers.IsType = SecureInternetServerType
- return &servers.SecureInternetHomeServer, nil
+ ss.IsType = SecureInternetServerType
+ return &ss.SecureInternetHomeServer, nil
}
-func (servers *Servers) GetCurrentServer() (Server, error) {
- errorMessage := "failed getting current server"
- if servers.IsType == SecureInternetServerType {
- if !servers.HasSecureLocation() {
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentNotFoundError{},
- )
+func (ss *Servers) GetCurrentServer() (Server, error) {
+ //TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server)
+ if ss.IsType == SecureInternetServerType {
+ if !ss.HasSecureLocation() {
+ return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType)
}
- return &servers.SecureInternetHomeServer, nil
+ return &ss.SecureInternetHomeServer, nil
}
- serversStruct := &servers.InstituteServers
+ srvs := &ss.InstituteServers
- if servers.IsType == CustomServerType {
- serversStruct = &servers.CustomServers
+ if ss.IsType == CustomServerType {
+ srvs = &ss.CustomServers
}
- currentServerURL := serversStruct.CurrentURL
- bases := serversStruct.Map
- if bases == nil {
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentNoMapError{},
- )
+ bs := srvs.Map
+ if bs == nil {
+ return nil, errors.Errorf("srvs.Map is nil")
}
- server, exists := bases[currentServerURL]
- if !exists || server == nil {
- return nil, types.NewWrappedError(
- errorMessage,
- &CurrentNotFoundError{},
- )
+ if srv, ok := bs[srvs.CurrentURL]; !ok || srv == nil {
+ return nil, errors.Errorf("server not found")
+ } else {
+ return srv, nil
}
- return server, nil
}
-func (servers *Servers) addInstituteAndCustom(
+func (ss *Servers) addInstituteAndCustom(
discoServer *types.DiscoveryServer,
isCustom bool,
) (Server, error) {
url := discoServer.BaseURL
- errorMessage := fmt.Sprintf("failed adding institute access server: %s", url)
- toAddServers := &servers.InstituteServers
- serverType := InstituteAccessServerType
+ srvs := &ss.InstituteServers
+ srvType := InstituteAccessServerType
if isCustom {
- toAddServers = &servers.CustomServers
- serverType = CustomServerType
+ srvs = &ss.CustomServers
+ srvType = CustomServerType
}
- if toAddServers.Map == nil {
- toAddServers.Map = make(map[string]*InstituteAccessServer)
+ if srvs.Map == nil {
+ srvs.Map = make(map[string]*InstituteAccessServer)
}
- server, exists := toAddServers.Map[url]
+ srv, ok := srvs.Map[url]
// initialize the server if it doesn't exist yet
- if !exists {
- server = &InstituteAccessServer{}
+ if !ok {
+ srv = &InstituteAccessServer{}
}
- instituteInitErr := server.init(
- url,
- discoServer.DisplayName,
- discoServer.Type,
- discoServer.SupportContact,
- )
- if instituteInitErr != nil {
- return nil, types.NewWrappedError(errorMessage, instituteInitErr)
+ if err := srv.init(url, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil {
+ return nil, err
}
- toAddServers.Map[url] = server
- servers.IsType = serverType
- return server, nil
+ srvs.Map[url] = srv
+ ss.IsType = srvType
+ return srv, nil
}
-func (servers *Servers) AddInstituteAccessServer(
+func (ss *Servers) AddInstituteAccessServer(
instituteServer *types.DiscoveryServer,
) (Server, error) {
- return servers.addInstituteAndCustom(instituteServer, false)
+ return ss.addInstituteAndCustom(instituteServer, false)
}
-func (servers *Servers) AddCustomServer(
+func (ss *Servers) AddCustomServer(
customServer *types.DiscoveryServer,
) (Server, error) {
- return servers.addInstituteAndCustom(customServer, true)
+ return ss.addInstituteAndCustom(customServer, true)
}
-func (servers *Servers) GetSecureLocation() string {
- return servers.SecureInternetHomeServer.CurrentLocation
+func (ss *Servers) GetSecureLocation() string {
+ return ss.SecureInternetHomeServer.CurrentLocation
}
-func (servers *Servers) SetSecureLocation(
+func (ss *Servers) SetSecureLocation(
chosenLocationServer *types.DiscoveryServer,
) error {
- errorMessage := "failed to set secure location"
// Make sure to add the current location
- _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer)
- if addLocationErr != nil {
- return types.NewWrappedError(errorMessage, addLocationErr)
+ if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil {
+ return err
}
- servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
+ ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
return nil
}