From 3ac1d35257b56cca92ad0eb7f4d18abb366cf105 Mon Sep 17 00:00:00 2001 From: Aleksandar Pesic Date: Sun, 4 Dec 2022 21:48:20 +0100 Subject: simplify error handling fixes #6 Signed-off-by: Aleksandar Pesic --- internal/server/api.go | 214 +++++++++++-------------- internal/server/base.go | 29 ++-- internal/server/custom.go | 41 ++--- internal/server/instituteaccess.go | 85 +++++----- internal/server/secureinternet.go | 151 +++++++----------- internal/server/server.go | 318 ++++++++++++++----------------------- internal/server/servers.go | 117 ++++++-------- 7 files changed, 380 insertions(+), 575 deletions(-) (limited to 'internal/server') diff --git a/internal/server/api.go b/internal/server/api.go index 21ba6f4..dfa8e14 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -10,130 +9,114 @@ import ( "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) func APIGetEndpoints(baseURL string) (*Endpoints, error) { - errorMessage := "failed getting server endpoints" - url, urlErr := url.Parse(baseURL) - if urlErr != nil { - return nil, types.NewWrappedError(errorMessage, urlErr) + u, err := url.Parse(baseURL) + if err != nil { + return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - wellKnownPath := "/.well-known/vpn-user-portal" + wk := "/.well-known/vpn-user-portal" - url.Path = path.Join(url.Path, wellKnownPath) - _, body, bodyErr := httpw.Get(url.String()) - - if bodyErr != nil { - return nil, types.NewWrappedError(errorMessage, bodyErr) + u.Path = path.Join(u.Path, wk) + _, body, err := httpw.Get(u.String()) + if err != nil { + return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - endpoints := &Endpoints{} - jsonErr := json.Unmarshal(body, endpoints) - - if jsonErr != nil { - return nil, types.NewWrappedError(errorMessage, jsonErr) + ep := &Endpoints{} + if err = json.Unmarshal(body, ep); err != nil { + return nil, errors.WrapPrefix(err, "failed getting server endpoints", 0) } - return endpoints, nil + return ep, nil } func apiAuthorized( - server Server, + srv Server, method string, endpoint string, opts *httpw.OptionalParams, ) (http.Header, []byte, error) { - errorMessage := "failed API authorized" // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &httpw.OptionalParams{} } - base, baseErr := server.Base() - - if baseErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0) } // Join the paths - url, urlErr := url.Parse(base.Endpoints.API.V3.API) - if urlErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, urlErr) + u, err := url.Parse(b.Endpoints.API.V3.API) + if err != nil { + return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0) } - url.Path = path.Join(url.Path, endpoint) + u.Path = path.Join(u.Path, endpoint) // Make sure the tokens are valid, this will return an error if re-login is needed - token, tokenErr := HeaderToken(server) - if tokenErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, tokenErr) + t, err := HeaderToken(srv) + if err != nil { + return nil, nil, errors.WrapPrefix(err, "failed API authorized", 0) } - headerKey := "Authorization" - headerValue := fmt.Sprintf("Bearer %s", token) + key := "Authorization" + val := fmt.Sprintf("Bearer %s", t) if opts.Headers != nil { - opts.Headers.Add(headerKey, headerValue) + opts.Headers.Add(key, val) } else { - opts.Headers = http.Header{headerKey: {headerValue}} + opts.Headers = http.Header{key: {val}} } - return httpw.MethodWithOpts(method, url.String(), opts) + return httpw.MethodWithOpts(method, u.String(), opts) } func apiAuthorizedRetry( - server Server, + srv Server, method string, endpoint string, opts *httpw.OptionalParams, ) (http.Header, []byte, error) { - errorMessage := "failed authorized API retry" - header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) - - if bodyErr != nil { - var error *httpw.StatusError - - // Only retry authorized if we get a HTTP 401 - if errors.As(bodyErr, &error) && error.Status == 401 { - // Mark the token as expired and retry so we trigger the refresh flow - MarkTokenExpired(server) - retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) - if retryErr != nil { - return nil, nil, types.NewWrappedError(errorMessage, retryErr) - } - return retryHeader, retryBody, nil - } - return nil, nil, types.NewWrappedError(errorMessage, bodyErr) - } - return header, body, nil -} - -func APIInfo(server Server) error { - errorMessage := "failed API /info" - _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil) - if bodyErr != nil { - return types.NewWrappedError(errorMessage, bodyErr) + h, body, err := apiAuthorized(srv, method, endpoint, opts) + if err == nil { + return h, body, nil } - structure := ProfileInfo{} - jsonErr := json.Unmarshal(body, &structure) - if jsonErr != nil { - return types.NewWrappedError(errorMessage, jsonErr) + statErr := &httpw.StatusError{} + // Only retry authorized if we get an HTTP 401 + if errors.As(err, &statErr) && statErr.Status == 401 { + // Mark the token as expired and retry, so we trigger the refresh flow + MarkTokenExpired(srv) + h, body, err = apiAuthorized(srv, method, endpoint, opts) } + return h, body, err +} - base, baseErr := server.Base() +func APIInfo(srv Server) error { + _, body, err := apiAuthorizedRetry(srv, http.MethodGet, "/info", nil) + if err != nil { + return err + } + pi := ProfileInfo{} + if err = json.Unmarshal(body, &pi); err != nil { + return errors.WrapPrefix(err, "failed API /info", 0) + } - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return err } // Store the profiles and make sure that the current profile is not overwritten - previousProfile := base.Profiles.Current - base.Profiles = structure - base.Profiles.Current = previousProfile + prev := b.Profiles.Current + b.Profiles = pi + b.Profiles.Current = prev return nil } // see https://github.com/eduvpn/documentation/blob/v3/API.md#request-1 -func GetPreferTCPString(preferTCP bool) string { +func boolToYesNo(preferTCP bool) string { if preferTCP { return "yes" } @@ -141,88 +124,77 @@ func GetPreferTCPString(preferTCP bool) string { } func APIConnectWireguard( - server Server, + srv Server, profileID string, pubkey string, preferTCP bool, - supportsOpenVPN bool, + openVPNSupport bool, ) (string, string, time.Time, error) { - errorMessage := "failed obtaining a WireGuard configuration" - headers := http.Header{ + hdrs := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, } // This profile also supports OpenVPN // Indicate that we also accept OpenVPN profiles - if supportsOpenVPN { - headers.Add("accept", "application/x-openvpn-profile") + if openVPNSupport { + hdrs.Add("accept", "application/x-openvpn-profile") } - urlForm := url.Values{ + vals := url.Values{ "profile_id": {profileID}, "public_key": {pubkey}, - "prefer_tcp": {GetPreferTCPString(preferTCP)}, + "prefer_tcp": {boolToYesNo(preferTCP)}, } - header, connectBody, connectErr := apiAuthorizedRetry( - server, - http.MethodPost, - "/connect", - &httpw.OptionalParams{Headers: headers, Body: urlForm}, - ) - if connectErr != nil { - return "", "", time.Time{}, types.NewWrappedError( - errorMessage, - connectErr, - ) + h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", "", time.Time{}, err } - expires := header.Get("expires") + exp := h.Get("expires") - pTime, pTimeErr := http.ParseTime(expires) - if pTimeErr != nil { - return "", "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr) + ptm, err := http.ParseTime(exp) + if err != nil { + return "", "", time.Time{}, errors.WrapPrefix(err, "failed obtaining a WireGuard configuration", 0) } - contentType := header.Get("content-type") - - content := "openvpn" - if contentType == "application/x-wireguard-profile" { - content = "wireguard" + ct := h.Get("content-type") + c := "openvpn" + if ct == "application/x-wireguard-profile" { + c = "wireguard" } - return string(connectBody), content, pTime, nil + + return string(body), c, ptm, nil } -func APIConnectOpenVPN(server Server, profileID string, preferTCP bool) (string, time.Time, error) { - errorMessage := "failed obtaining an OpenVPN configuration" - headers := http.Header{ +func APIConnectOpenVPN(srv Server, profileID string, preferTCP bool) (string, time.Time, error) { + hdrs := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, } - urlForm := url.Values{ + vals := url.Values{ "profile_id": {profileID}, - "prefer_tcp": {GetPreferTCPString(preferTCP)}, + "prefer_tcp": {boolToYesNo(preferTCP)}, } - header, connectBody, connectErr := apiAuthorizedRetry( - server, - http.MethodPost, - "/connect", - &httpw.OptionalParams{Headers: headers, Body: urlForm}, - ) - if connectErr != nil { - return "", time.Time{}, types.NewWrappedError(errorMessage, connectErr) + h, body, err := apiAuthorizedRetry(srv, http.MethodPost, "/connect", + &httpw.OptionalParams{Headers: hdrs, Body: vals}) + if err != nil { + return "", time.Time{}, err } - expires := header.Get("expires") - pTime, pTimeErr := http.ParseTime(expires) - if pTimeErr != nil { - return "", time.Time{}, types.NewWrappedError(errorMessage, pTimeErr) + exp := h.Get("expires") + ptm, err := http.ParseTime(exp) + if err != nil { + return "", time.Time{}, errors.WrapPrefix(err, "failed obtaining an OpenVPN configuration", 0) } - return string(connectBody), pTime, nil + + return string(body), ptm, nil } +// APIDisconnect disconnects from the API. // This needs no further return value as it's best effort. func APIDisconnect(server Server) { _, _, _ = apiAuthorized(server, http.MethodPost, "/disconnect", nil) diff --git a/internal/server/base.go b/internal/server/base.go index bb88eb3..81049cf 100644 --- a/internal/server/base.go +++ b/internal/server/base.go @@ -2,11 +2,9 @@ package server import ( "time" - - "github.com/eduvpn/eduvpn-common/types" ) -// The base type for servers. +// Base is the base type for servers. type Base struct { URL string `json:"base_url"` DisplayName map[string]string `json:"display_name"` @@ -18,28 +16,27 @@ type Base struct { Type string `json:"server_type"` } -func (base *Base) InitializeEndpoints() error { - errorMessage := "failed initializing endpoints" - endpoints, endpointsErr := APIGetEndpoints(base.URL) - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) +func (b *Base) InitializeEndpoints() error { + ep, err := APIGetEndpoints(b.URL) + if err != nil { + return err } - base.Endpoints = *endpoints + b.Endpoints = *ep return nil } -func (base *Base) ValidProfiles(clientSupportsWireguard bool) ProfileInfo { - var validProfiles []Profile - for _, profile := range base.Profiles.Info.ProfileList { +func (b *Base) ValidProfiles(wireguardSupport bool) ProfileInfo { + var vps []Profile + for _, p := range b.Profiles.Info.ProfileList { // Not a valid profile because it does not support openvpn // Also the client does not support wireguard - if !profile.supportsOpenVPN() && !clientSupportsWireguard { + if !p.supportsOpenVPN() && !wireguardSupport { continue } - validProfiles = append(validProfiles, profile) + vps = append(vps, p) } return ProfileInfo{ - Current: base.Profiles.Current, - Info: ProfileListInfo{ProfileList: validProfiles}, + Current: b.Profiles.Current, + Info: ProfileListInfo{ProfileList: vps}, } } diff --git a/internal/server/custom.go b/internal/server/custom.go index d376727..bf0b230 100644 --- a/internal/server/custom.go +++ b/internal/server/custom.go @@ -1,42 +1,35 @@ package server import ( - "errors" - "fmt" - - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -func (servers *Servers) SetCustomServer(server Server) error { - errorMessage := "failed setting custom server" - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) +func (ss *Servers) SetCustomServer(server Server) error { + b, err := server.Base() + if err != nil { + return err } - if base.Type != "custom_server" { - return types.NewWrappedError(errorMessage, errors.New("not a custom server")) + if b.Type != "custom_server" { + return errors.WrapPrefix(err, "not a custom server", 0) } - if _, ok := servers.CustomServers.Map[base.URL]; ok { - servers.CustomServers.CurrentURL = base.URL - servers.IsType = CustomServerType + if _, ok := ss.CustomServers.Map[b.URL]; ok { + ss.CustomServers.CurrentURL = b.URL + ss.IsType = CustomServerType } else { - return types.NewWrappedError(errorMessage, errors.New("not a custom server")) + return errors.Errorf("not a custom server") } return nil } -func (servers *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { - if server, ok := servers.CustomServers.Map[url]; ok { - return server, nil +func (ss *Servers) GetCustomServer(url string) (*InstituteAccessServer, error) { + if srv, ok := ss.CustomServers.Map[url]; ok { + return srv, nil } - return nil, types.NewWrappedError( - "failed to get institute access server", - fmt.Errorf("no custom server with URL: %s", url), - ) + return nil, errors.Errorf("failed to get institute access server - no custom server with URL '%s'", url) } -func (servers *Servers) RemoveCustomServer(url string) { - servers.CustomServers.Remove(url) +func (ss *Servers) RemoveCustomServer(url string) { + ss.CustomServers.Remove(url) } diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index 9b6f735..56ed1cf 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -1,14 +1,10 @@ package server import ( - "errors" - "fmt" - "github.com/eduvpn/eduvpn-common/internal/oauth" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -// An instute access server. type InstituteAccessServer struct { // An instute access server has its own OAuth Auth oauth.OAuth `json:"oauth"` @@ -22,80 +18,75 @@ type InstituteAccessServers struct { CurrentURL string `json:"current_url"` } -func (servers *Servers) SetInstituteAccess(server Server) error { - errorMessage := "failed setting institute access server" - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) +func (ss *Servers) SetInstituteAccess(srv Server) error { + b, err := srv.Base() + if err != nil { + return err } - if base.Type != "institute_access" { - return types.NewWrappedError(errorMessage, errors.New("not an institute access server")) + if b.Type != "institute_access" { + return errors.Errorf("not an institute access server") } - if _, ok := servers.InstituteServers.Map[base.URL]; ok { - servers.InstituteServers.CurrentURL = base.URL - servers.IsType = InstituteAccessServerType + if _, ok := ss.InstituteServers.Map[b.URL]; ok { + ss.InstituteServers.CurrentURL = b.URL + ss.IsType = InstituteAccessServerType } else { - return types.NewWrappedError(errorMessage, errors.New("no such institute access server")) + return errors.Errorf("no such institute access server") } return nil } -func (servers *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { - if server, ok := servers.InstituteServers.Map[url]; ok { - return server, nil +func (ss *Servers) GetInstituteAccess(url string) (*InstituteAccessServer, error) { + if srv, ok := ss.InstituteServers.Map[url]; ok { + return srv, nil } - return nil, types.NewWrappedError( - "failed to get institute access server", - fmt.Errorf("no institute access server with URL: %s", url), - ) + return nil, errors.Errorf("no institute access server with URL: %s", url) } -func (servers *Servers) RemoveInstituteAccess(url string) { - servers.InstituteServers.Remove(url) +func (ss *Servers) RemoveInstituteAccess(url string) { + ss.InstituteServers.Remove(url) } -func (servers *InstituteAccessServers) Remove(url string) { +func (iass *InstituteAccessServers) Remove(url string) { // Reset the current url - if servers.CurrentURL == url { - servers.CurrentURL = "" + if iass.CurrentURL == url { + iass.CurrentURL = "" } // Delete the url from the map - delete(servers.Map, url) + delete(iass.Map, url) } -func (institute *InstituteAccessServer) TemplateAuth() func(string) string { +func (ias *InstituteAccessServer) TemplateAuth() func(string) string { return func(authURL string) string { return authURL } } -func (institute *InstituteAccessServer) Base() (*Base, error) { - return &institute.Basic, nil +func (ias *InstituteAccessServer) Base() (*Base, error) { + return &ias.Basic, nil } -func (institute *InstituteAccessServer) OAuth() *oauth.OAuth { - return &institute.Auth +func (ias *InstituteAccessServer) OAuth() *oauth.OAuth { + return &ias.Auth } -func (institute *InstituteAccessServer) init( +func (ias *InstituteAccessServer) init( url string, - displayName map[string]string, - serverType string, + name map[string]string, + srvType string, supportContact []string, ) error { - errorMessage := fmt.Sprintf("failed initializing server %s", url) - institute.Basic.URL = url - institute.Basic.DisplayName = displayName - institute.Basic.SupportContact = supportContact - institute.Basic.Type = serverType - endpointsErr := institute.Basic.InitializeEndpoints() - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) + ias.Basic.URL = url + ias.Basic.DisplayName = name + ias.Basic.SupportContact = supportContact + ias.Basic.Type = srvType + err := ias.Basic.InitializeEndpoints() + if err != nil { + return err } - API := institute.Basic.Endpoints.API.V3 - institute.Auth.Init(url, API.Authorization, API.Token) + API := ias.Basic.Endpoints.API.V3 + ias.Auth.Init(url, API.Authorization, API.Token) return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 998390d..12263a6 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -1,15 +1,13 @@ package server import ( - "errors" - "fmt" - "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/util" "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) -// A secure internet server which has its own OAuth tokens +// SecureInternetHomeServer secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to. type SecureInternetHomeServer struct { Auth oauth.OAuth `json:"oauth"` @@ -24,150 +22,111 @@ type SecureInternetHomeServer struct { CurrentLocation string `json:"current_location"` } -func (servers *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { - if !servers.HasSecureLocation() { - return nil, errors.New("no secure internet home server") +func (ss *Servers) GetSecureInternetHomeServer() (*SecureInternetHomeServer, error) { + if !ss.HasSecureLocation() { + return nil, errors.Errorf("no secure internet home server") } - return &servers.SecureInternetHomeServer, nil + return &ss.SecureInternetHomeServer, nil } -func (servers *Servers) SetSecureInternet(server Server) error { - errorMessage := "failed setting secure internet server" - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) +func (ss *Servers) SetSecureInternet(server Server) error { + b, err := server.Base() + if err != nil { + return err } - if base.Type != "secure_internet" { - return types.NewWrappedError(errorMessage, errors.New("not a secure internet server")) + if b.Type != "secure_internet" { + return errors.Errorf("not a secure internet server") } // The location should already be configured // TODO: check for location? - servers.IsType = SecureInternetServerType + ss.IsType = SecureInternetServerType return nil } -func (servers *Servers) RemoveSecureInternet() { +func (ss *Servers) RemoveSecureInternet() { // Empty out the struct - servers.SecureInternetHomeServer = SecureInternetHomeServer{} + ss.SecureInternetHomeServer = SecureInternetHomeServer{} // If the current server is secure internet, default to custom server - if servers.IsType == SecureInternetServerType { - servers.IsType = CustomServerType + if ss.IsType == SecureInternetServerType { + ss.IsType = CustomServerType } } -func (server *SecureInternetHomeServer) TemplateAuth() func(string) string { +func (s *SecureInternetHomeServer) TemplateAuth() func(string) string { return func(authURL string) string { - return util.ReplaceWAYF(server.AuthorizationTemplate, authURL, server.HomeOrganizationID) + return util.ReplaceWAYF(s.AuthorizationTemplate, authURL, s.HomeOrganizationID) } } -func (server *SecureInternetHomeServer) Base() (*Base, error) { - errorMessage := "failed getting current secure internet home base" - if server.BaseMap == nil { - return nil, types.NewWrappedError( - errorMessage, - &SecureInternetMapNotFoundError{}, - ) +func (s *SecureInternetHomeServer) Base() (*Base, error) { + if s.BaseMap == nil { + return nil, errors.Errorf("secure internet map not found") } - base, exists := server.BaseMap[server.CurrentLocation] - - if !exists { - return nil, types.NewWrappedError( - errorMessage, - &SecureInternetBaseNotFoundError{Current: server.CurrentLocation}, - ) + b, ok := s.BaseMap[s.CurrentLocation] + if !ok { + return nil, errors.Errorf("secure internet base with location '%s' not found", s.CurrentLocation) } - return base, nil + return b, nil } -func (server *SecureInternetHomeServer) OAuth() *oauth.OAuth { - return &server.Auth +func (s *SecureInternetHomeServer) OAuth() *oauth.OAuth { + return &s.Auth } -func (servers *Servers) HasSecureLocation() bool { - return servers.SecureInternetHomeServer.CurrentLocation != "" +func (ss *Servers) HasSecureLocation() bool { + return ss.SecureInternetHomeServer.CurrentLocation != "" } -func (server *SecureInternetHomeServer) addLocation( - locationServer *types.DiscoveryServer, -) (*Base, error) { - errorMessage := "failed adding a location" +func (s *SecureInternetHomeServer) addLocation(locSrv *types.DiscoveryServer) (*Base, error) { // Initialize the base map if it is non-nil - if server.BaseMap == nil { - server.BaseMap = make(map[string]*Base) + if s.BaseMap == nil { + s.BaseMap = make(map[string]*Base) } // Add the location to the base map - base, exists := server.BaseMap[locationServer.CountryCode] - - if !exists || base == nil { + b, ok := s.BaseMap[locSrv.CountryCode] + if !ok || b == nil { // Create the base to be added to the map - base = &Base{} - base.URL = locationServer.BaseURL - base.DisplayName = server.DisplayName - base.SupportContact = locationServer.SupportContact - base.Type = "secure_internet" - endpointsErr := base.InitializeEndpoints() - if endpointsErr != nil { - return nil, types.NewWrappedError(errorMessage, endpointsErr) + b = &Base{} + b.URL = locSrv.BaseURL + b.DisplayName = s.DisplayName + b.SupportContact = locSrv.SupportContact + b.Type = "secure_internet" + if err := b.InitializeEndpoints(); err != nil { + return nil, err } } // Ensure it is in the map - server.BaseMap[locationServer.CountryCode] = base - return base, nil + s.BaseMap[locSrv.CountryCode] = b + return b, nil } // Initializes the home server and adds its own location. -func (server *SecureInternetHomeServer) init( - homeOrg *types.DiscoveryOrganization, - homeLocation *types.DiscoveryServer, -) error { - errorMessage := "failed initializing secure internet home server" - - if server.HomeOrganizationID != homeOrg.OrgID { +func (s *SecureInternetHomeServer) init( + homeOrg *types.DiscoveryOrganization, homeLoc *types.DiscoveryServer) error { + if s.HomeOrganizationID != homeOrg.OrgID { // New home organisation, clear everything - *server = SecureInternetHomeServer{} + *s = SecureInternetHomeServer{} } // Make sure to set the organization ID - server.HomeOrganizationID = homeOrg.OrgID - server.DisplayName = homeOrg.DisplayName + s.HomeOrganizationID = homeOrg.OrgID + s.DisplayName = homeOrg.DisplayName // Make sure to set the authorization URL template - server.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate + s.AuthorizationTemplate = homeLoc.AuthenticationURLTemplate - base, baseErr := server.addLocation(homeLocation) - - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) + b, err := s.addLocation(homeLoc) + if err != nil { + return err } // Make sure oauth contains our endpoints - server.Auth.Init(base.URL, base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token) + s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) return nil } - -type SecureInternetHomeNotFoundError struct{} - -func (e *SecureInternetHomeNotFoundError) Error() string { - return "failed to get secure internet home server, not found" -} - -type SecureInternetMapNotFoundError struct{} - -func (e *SecureInternetMapNotFoundError) Error() string { - return "secure internet map not found" -} - -type SecureInternetBaseNotFoundError struct { - Current string -} - -func (e *SecureInternetBaseNotFoundError) Error() string { - return fmt.Sprintf("secure internet base not found with current location: %s", e.Current) -} diff --git a/internal/server/server.go b/internal/server/server.go index 95244d5..de0fa9a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,11 @@ package server import ( - "errors" - "fmt" "time" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/wireguard" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) type Type int8 @@ -21,10 +19,10 @@ const ( type Server interface { OAuth() *oauth.OAuth - // Get the authorization URL template function + // TemplateAuth returns the authorization URL template function TemplateAuth() func(string) string - // Gets the server base + // Base returns the server base Base() (*Base, error) } @@ -34,7 +32,7 @@ type EndpointList struct { Token string `json:"token_endpoint"` } -// Struct that defines the json format for /.well-known/vpn-user-portal". +// Endpoints defines the json format for /.well-known/vpn-user-portal". type Endpoints struct { API struct { V2 EndpointList `json:"http://eduvpn.org/api#2"` @@ -43,310 +41,226 @@ type Endpoints struct { V string `json:"v"` } -func ShouldRenewButton(server Server) bool { - base, baseErr := server.Base() - - if baseErr != nil { +func ShouldRenewButton(srv Server) bool { + b, err := srv.Base() + if err != nil { // FIXME: Log error here? return false } // Get current time - current := time.Now() + now := time.Now() // Session is expired - if !current.Before(base.EndTime) { + if !now.Before(b.EndTime) { return true } // 30 minutes have not passed - if !current.After(base.StartTime.Add(30 * time.Minute)) { + if !now.After(b.StartTime.Add(30 * time.Minute)) { return false } // Session will not expire today - if !current.Add(24 * time.Hour).After(base.EndTime) { + if !now.Add(24 * time.Hour).After(b.EndTime) { return false } // Session duration is less than 24 hours but not 75% has passed - duration := base.EndTime.Sub(base.StartTime) - percentTime := base.StartTime.Add((duration / 4) * 3) - if duration < time.Duration(24*time.Hour) && !current.After(percentTime) { + d := b.EndTime.Sub(b.StartTime) + pct := b.StartTime.Add((d / 4) * 3) + if d < 24*time.Hour && !now.After(pct) { return false } return true } -func OAuthURL(server Server, name string) (string, error) { - return server.OAuth().AuthURL(name, server.TemplateAuth()) +func OAuthURL(srv Server, name string) (string, error) { + return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } -func OAuthExchange(server Server) error { - return server.OAuth().Exchange() +func OAuthExchange(srv Server) error { + return srv.OAuth().Exchange() } -func HeaderToken(server Server) (string, error) { - token, tokenErr := server.OAuth().AccessToken() - if tokenErr != nil { - return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr) - } - return token, nil +func HeaderToken(srv Server) (string, error) { + return srv.OAuth().AccessToken() } -func MarkTokenExpired(server Server) { - server.OAuth().SetTokenExpired() +func MarkTokenExpired(srv Server) { + srv.OAuth().SetTokenExpired() } -func MarkTokensForRenew(server Server) { - server.OAuth().SetTokenRenew() +func MarkTokensForRenew(srv Server) { + srv.OAuth().SetTokenRenew() } -func NeedsRelogin(server Server) bool { - _, tokenErr := HeaderToken(server) - return tokenErr != nil +func NeedsRelogin(srv Server) bool { + _, err := HeaderToken(srv) + return err != nil } -func CancelOAuth(server Server) { - server.OAuth().Cancel() +func CancelOAuth(srv Server) { + srv.OAuth().Cancel() } -func CurrentProfile(server Server) (*Profile, error) { - errorMessage := "failed getting current profile" - base, baseErr := server.Base() - - if baseErr != nil { - return nil, types.NewWrappedError(errorMessage, baseErr) +func CurrentProfile(srv Server) (*Profile, error) { + b, err := srv.Base() + if err != nil { + return nil, err } - profileID := base.Profiles.Current - for _, profile := range base.Profiles.Info.ProfileList { - if profile.ID == profileID { + pid := b.Profiles.Current + for _, profile := range b.Profiles.Info.ProfileList { + if profile.ID == pid { return &profile, nil } } - return nil, types.NewWrappedError( - errorMessage, - &CurrentProfileNotFoundError{ProfileID: profileID}, - ) + return nil, errors.Errorf("profile not found: " + pid) } -func ValidProfiles(server Server, clientSupportsWireguard bool) (*ProfileInfo, error) { - errorMessage := "failed to get valid profiles" +func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { // No error wrapping here otherwise we wrap it too much - base, baseErr := server.Base() - if baseErr != nil { - return nil, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return nil, err } - profiles := base.ValidProfiles(clientSupportsWireguard) - if len(profiles.Info.ProfileList) == 0 { - return nil, types.NewWrappedError( - errorMessage, - errors.New("no profiles found with supported protocols"), - ) + ps := b.ValidProfiles(wireguardSupport) + if len(ps.Info.ProfileList) == 0 { + return nil, errors.Errorf("no profiles found with supported protocols") } - return &profiles, nil + return &ps, nil } -func wireguardGetConfig( - server Server, - preferTCP bool, - supportsOpenVPN bool, -) (string, string, error) { - errorMessage := "failed getting server WireGuard configuration" - base, baseErr := server.Base() - - if baseErr != nil { - return "", "", types.NewWrappedError(errorMessage, baseErr) +func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) { + b, err := srv.Base() + if err != nil { + return "", "", err } - profileID := base.Profiles.Current - wireguardKey, wireguardErr := wireguard.GenerateKey() - - if wireguardErr != nil { - return "", "", types.NewWrappedError(errorMessage, wireguardErr) + pid := b.Profiles.Current + key, err := wireguard.GenerateKey() + if err != nil { + return "", "", err } - wireguardPublicKey := wireguardKey.PublicKey().String() - config, content, expires, configErr := APIConnectWireguard( - server, - profileID, - wireguardPublicKey, - preferTCP, - supportsOpenVPN, - ) - - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) + pub := key.PublicKey().String() + cfg, ct, exp, err := APIConnectWireguard(srv, pid, pub, preferTCP, openVPNSupport) + if err != nil { + return "", "", err } // Store start and end time - base.StartTime = time.Now() - base.EndTime = expires + b.StartTime = time.Now() + b.EndTime = exp - if content == "wireguard" { + if ct == "wireguard" { // This needs the go code a way to identify a connection // Use the uuid of the connection e.g. on Linux // This needs the client code to call the go code - config = wireguard.ConfigAddKey(config, wireguardKey) + cfg = wireguard.ConfigAddKey(cfg, key) } - return config, content, nil + return cfg, ct, nil } -func openVPNGetConfig(server Server, preferTCP bool) (string, string, error) { - errorMessage := "failed getting server OpenVPN configuration" - base, baseErr := server.Base() - - if baseErr != nil { - return "", "", types.NewWrappedError(errorMessage, baseErr) +func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) { + b, err := srv.Base() + if err != nil { + return "", "", err } - profileID := base.Profiles.Current - configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profileID, preferTCP) + pid := b.Profiles.Current + cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) // Store start and end time - base.StartTime = time.Now() - base.EndTime = expires + b.StartTime = time.Now() + b.EndTime = exp - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) + if err != nil { + return "", "", err } - return configOpenVPN, "openvpn", nil + return cfg, "openvpn", nil } -func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) { - errorMessage := "failed has valid profile check" - +func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { // Get new profiles using the info call // This does not override the current profile - infoErr := APIInfo(server) - if infoErr != nil { - return false, types.NewWrappedError(errorMessage, infoErr) + err := APIInfo(srv) + if err != nil { + return false, err } - base, baseErr := server.Base() - if baseErr != nil { - return false, types.NewWrappedError(errorMessage, baseErr) + b, err := srv.Base() + if err != nil { + return false, err } // If there was a profile chosen and it doesn't exist anymore, reset it - if base.Profiles.Current != "" { - _, existsProfileErr := CurrentProfile(server) - if existsProfileErr != nil { - base.Profiles.Current = "" + if b.Profiles.Current != "" { + if _, err = CurrentProfile(srv); err != nil { + b.Profiles.Current = "" } } - // Set the current profile if there is only one profile or profile is already selected - if len(base.Profiles.Info.ProfileList) == 1 || base.Profiles.Current != "" { - // Set the first profile if none is selected - if base.Profiles.Current == "" { - base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID - } - profile, profileErr := CurrentProfile(server) - // shouldn't happen - if profileErr != nil { - return false, types.NewWrappedError(errorMessage, profileErr) - } - // Profile does not support OpenVPN but the client also doesn't support WireGuard - if !profile.supportsOpenVPN() && !clientSupportsWireguard { - return false, nil - } - return true, nil + if len(b.Profiles.Info.ProfileList) != 1 && b.Profiles.Current == "" { + return false, nil } - return false, nil + // Set the current profile if there is only one profile or profile is already selected + // Set the first profile if none is selected + if b.Profiles.Current == "" { + b.Profiles.Current = b.Profiles.Info.ProfileList[0].ID + } + p, err := CurrentProfile(srv) + // shouldn't happen + if err != nil { + return false, err + } + // Profile does not support OpenVPN but the client also doesn't support WireGuard + if !p.supportsOpenVPN() && !wireguardSupport { + return false, nil + } + return true, nil } -func RefreshEndpoints(server Server) error { - errorMessage := "failed to refresh server endpoints" - +func RefreshEndpoints(srv Server) error { // Re-initialize the endpoints // TODO: Make this a warning instead? - base, baseErr := server.Base() - if baseErr != nil { - return types.NewWrappedError(errorMessage, baseErr) - } - - endpointsErr := base.InitializeEndpoints() - if endpointsErr != nil { - return types.NewWrappedError(errorMessage, endpointsErr) + b, err := srv.Base() + if err != nil { + return err } - return nil + return b.InitializeEndpoints() } -func Config(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { - errorMessage := "failed getting an OpenVPN/WireGuard configuration" - - profile, profileErr := CurrentProfile(server) - if profileErr != nil { - return "", "", types.NewWrappedError(errorMessage, profileErr) +func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) { + p, err := CurrentProfile(server) + if err != nil { + return "", "", err } - supportsOpenVPN := profile.supportsOpenVPN() - supportsWireguard := profile.supportsWireguard() && clientSupportsWireguard - - var config string - var configType string - var configErr error + ovpn := p.supportsOpenVPN() + wg := p.supportsWireguard() && wireguardSupport switch { // The config supports wireguard and optionally openvpn - case supportsWireguard: + case wg: // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - config, configType, configErr = wireguardGetConfig(server, preferTCP, supportsOpenVPN) + return wireguardGetConfig(server, preferTCP, ovpn) // The config only supports OpenVPN - case supportsOpenVPN: - config, configType, configErr = openVPNGetConfig(server, preferTCP) + case ovpn: + return openVPNGetConfig(server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: - return "", "", types.NewWrappedError(errorMessage, errors.New("no supported protocol found")) + return "", "", errors.Errorf("no supported protocol found") } - - if configErr != nil { - return "", "", types.NewWrappedError(errorMessage, configErr) - } - - return config, configType, nil } func Disconnect(server Server) { APIDisconnect(server) } - -type CurrentProfileNotFoundError struct { - ProfileID string -} - -func (e *CurrentProfileNotFoundError) Error() string { - return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) -} - -type ConfigPreferTCPError struct{} - -func (e *ConfigPreferTCPError) Error() string { - return "failed to get config, prefer TCP is on but the server does not support OpenVPN" -} - -type EmptyURLError struct{} - -func (e *EmptyURLError) Error() string { - return "failed ensuring server, empty url provided" -} - -type CurrentNoMapError struct{} - -func (e *CurrentNoMapError) Error() string { - return "failed getting current server, no servers available" -} - -type CurrentNotFoundError struct{} - -func (e *CurrentNotFoundError) Error() string { - return "failed getting current server, not found" -} diff --git a/internal/server/servers.go b/internal/server/servers.go index a076770..b34dcff 100644 --- a/internal/server/servers.go +++ b/internal/server/servers.go @@ -1,9 +1,8 @@ package server import ( - "fmt" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) type Servers struct { @@ -14,125 +13,105 @@ type Servers struct { IsType Type `json:"is_secure_internet"` } -func (servers *Servers) AddSecureInternet( +func (ss *Servers) AddSecureInternet( secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, ) (Server, error) { - errorMessage := "failed adding secure internet server" // If we have specified an organization ID // We also need to get an authorization template - initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer) + err := ss.SecureInternetHomeServer.init(secureOrg, secureServer) - if initErr != nil { - return nil, types.NewWrappedError(errorMessage, initErr) + if err != nil { + return nil, err } - servers.IsType = SecureInternetServerType - return &servers.SecureInternetHomeServer, nil + ss.IsType = SecureInternetServerType + return &ss.SecureInternetHomeServer, nil } -func (servers *Servers) GetCurrentServer() (Server, error) { - errorMessage := "failed getting current server" - if servers.IsType == SecureInternetServerType { - if !servers.HasSecureLocation() { - return nil, types.NewWrappedError( - errorMessage, - &CurrentNotFoundError{}, - ) +func (ss *Servers) GetCurrentServer() (Server, error) { + //TODO(jwijenbergh): Almost certainly the return type should be pointer (*Server) + if ss.IsType == SecureInternetServerType { + if !ss.HasSecureLocation() { + return nil, errors.Errorf("ss.IsType = %v; ss.HasSecureLocation() = false", ss.IsType) } - return &servers.SecureInternetHomeServer, nil + return &ss.SecureInternetHomeServer, nil } - serversStruct := &servers.InstituteServers + srvs := &ss.InstituteServers - if servers.IsType == CustomServerType { - serversStruct = &servers.CustomServers + if ss.IsType == CustomServerType { + srvs = &ss.CustomServers } - currentServerURL := serversStruct.CurrentURL - bases := serversStruct.Map - if bases == nil { - return nil, types.NewWrappedError( - errorMessage, - &CurrentNoMapError{}, - ) + bs := srvs.Map + if bs == nil { + return nil, errors.Errorf("srvs.Map is nil") } - server, exists := bases[currentServerURL] - if !exists || server == nil { - return nil, types.NewWrappedError( - errorMessage, - &CurrentNotFoundError{}, - ) + if srv, ok := bs[srvs.CurrentURL]; !ok || srv == nil { + return nil, errors.Errorf("server not found") + } else { + return srv, nil } - return server, nil } -func (servers *Servers) addInstituteAndCustom( +func (ss *Servers) addInstituteAndCustom( discoServer *types.DiscoveryServer, isCustom bool, ) (Server, error) { url := discoServer.BaseURL - errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) - toAddServers := &servers.InstituteServers - serverType := InstituteAccessServerType + srvs := &ss.InstituteServers + srvType := InstituteAccessServerType if isCustom { - toAddServers = &servers.CustomServers - serverType = CustomServerType + srvs = &ss.CustomServers + srvType = CustomServerType } - if toAddServers.Map == nil { - toAddServers.Map = make(map[string]*InstituteAccessServer) + if srvs.Map == nil { + srvs.Map = make(map[string]*InstituteAccessServer) } - server, exists := toAddServers.Map[url] + srv, ok := srvs.Map[url] // initialize the server if it doesn't exist yet - if !exists { - server = &InstituteAccessServer{} + if !ok { + srv = &InstituteAccessServer{} } - instituteInitErr := server.init( - url, - discoServer.DisplayName, - discoServer.Type, - discoServer.SupportContact, - ) - if instituteInitErr != nil { - return nil, types.NewWrappedError(errorMessage, instituteInitErr) + if err := srv.init(url, discoServer.DisplayName, discoServer.Type, discoServer.SupportContact); err != nil { + return nil, err } - toAddServers.Map[url] = server - servers.IsType = serverType - return server, nil + srvs.Map[url] = srv + ss.IsType = srvType + return srv, nil } -func (servers *Servers) AddInstituteAccessServer( +func (ss *Servers) AddInstituteAccessServer( instituteServer *types.DiscoveryServer, ) (Server, error) { - return servers.addInstituteAndCustom(instituteServer, false) + return ss.addInstituteAndCustom(instituteServer, false) } -func (servers *Servers) AddCustomServer( +func (ss *Servers) AddCustomServer( customServer *types.DiscoveryServer, ) (Server, error) { - return servers.addInstituteAndCustom(customServer, true) + return ss.addInstituteAndCustom(customServer, true) } -func (servers *Servers) GetSecureLocation() string { - return servers.SecureInternetHomeServer.CurrentLocation +func (ss *Servers) GetSecureLocation() string { + return ss.SecureInternetHomeServer.CurrentLocation } -func (servers *Servers) SetSecureLocation( +func (ss *Servers) SetSecureLocation( chosenLocationServer *types.DiscoveryServer, ) error { - errorMessage := "failed to set secure location" // Make sure to add the current location - _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer) - if addLocationErr != nil { - return types.NewWrappedError(errorMessage, addLocationErr) + if _, err := ss.SecureInternetHomeServer.addLocation(chosenLocationServer); err != nil { + return err } - servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode + ss.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode return nil } -- cgit v1.2.3