From 2252135fadb8c579ad27345e3203be755130e3cd Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 20 Jun 2022 15:20:18 +0200 Subject: Refactor: Errors to have one custom type that is to be wrapped - For this an `internal/types` package is created with a custom error type - This custom error type can give back the cause and traceback of an error --- internal/server/api.go | 78 +++++++----------------- internal/server/server.go | 152 ++++++++++++++-------------------------------- 2 files changed, 69 insertions(+), 161 deletions(-) (limited to 'internal/server') diff --git a/internal/server/api.go b/internal/server/api.go index 96bd641..c8c7180 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -6,32 +6,34 @@ import ( "fmt" "net/http" "net/url" + httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" ) func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { + errorMessage := "failed getting server endpoints" url := fmt.Sprintf("%s/%s", baseURL, WellKnownPath) _, body, bodyErr := httpw.HTTPGet(url) if bodyErr != nil { - return nil, &APIGetEndpointsError{Err: bodyErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } endpoints := &ServerEndpoints{} jsonErr := json.Unmarshal(body, endpoints) if jsonErr != nil { - return nil, &APIGetEndpointsError{Err: jsonErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } return endpoints, nil } -// Authorized wrappers on top of HTTP -// the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) { + errorMessage := "failed API authorized" // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &httpw.HTTPOptionalParams{} @@ -39,7 +41,7 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT base, baseErr := server.GetBase() if baseErr != nil { - return nil, nil, baseErr + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } url := base.Endpoints.API.V3.API + endpoint @@ -52,7 +54,7 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT base.FSM.Current = stateBefore if oauthErr != nil { - return nil, nil, oauthErr + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } headerKey := "Authorization" @@ -66,11 +68,12 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT } func apiAuthorizedRetry(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) { + errorMessage := "failed authorized API retry" header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) base, baseErr := server.GetBase() if baseErr != nil { - return nil, nil, &APIAuthorizedError{Err: baseErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if bodyErr != nil { var error *httpw.HTTPStatusError @@ -82,31 +85,32 @@ func apiAuthorizedRetry(server Server, method string, endpoint string, opts *htt server.GetOAuth().Token.ExpiredTimestamp = util.GenerateTimeSeconds() retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { - return nil, nil, &APIAuthorizedError{Err: retryErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr} } return retryHeader, retryBody, nil } - return nil, nil, &APIAuthorizedError{Err: bodyErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } return header, body, nil } func APIInfo(server Server) error { + errorMessage := "failed API /info" _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil) if bodyErr != nil { - return &APIInfoError{Err: bodyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } structure := ServerProfileInfo{} jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { - return &APIInfoError{Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } base, baseErr := server.GetBase() if baseErr != nil { - return &APIInfoError{Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } // Store the profiles and make sure that the current profile is not overwritten @@ -118,6 +122,7 @@ func APIInfo(server Server) error { } func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, int64, error) { + errorMessage := "failed obtaining a WireGuard configuration" headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, @@ -133,14 +138,14 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor } header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", "", 0, &APIConnectWireguardError{Err: connectErr} + return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr} } expires := header.Get("expires") pTime, pTimeErr := http.ParseTime(expires) if pTimeErr != nil { - return "", "", 0, &APIConnectWireguardError{Err: pTimeErr} + return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr} } contentType := header.Get("content-type") @@ -153,6 +158,7 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor } func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) { + errorMessage := "failed obtaining an OpenVPN configuration" headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, @@ -164,13 +170,13 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", 0, &APIConnectOpenVPNError{Err: connectErr} + return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr} } expires := header.Get("expires") pTime, pTimeErr := http.ParseTime(expires) if pTimeErr != nil { - return "", 0, &APIConnectOpenVPNError{Err: pTimeErr} + return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr} } return string(connectBody), pTime.Unix(), nil } @@ -179,43 +185,3 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) func APIDisconnect(server Server) { apiAuthorizedRetry(server, http.MethodPost, "/disconnect", nil) } - -type APIAuthorizedError struct { - Err error -} - -func (e *APIAuthorizedError) Error() string { - return fmt.Sprintf("failed api authorized call with error: %v", e.Err) -} - -type APIConnectWireguardError struct { - Err error -} - -func (e *APIConnectWireguardError) Error() string { - return fmt.Sprintf("failed api /connect wireguard call with error: %v", e.Err) -} - -type APIConnectOpenVPNError struct { - Err error -} - -func (e *APIConnectOpenVPNError) Error() string { - return fmt.Sprintf("failed api /connect OpenVPN call with error: %v", e.Err) -} - -type APIInfoError struct { - Err error -} - -func (e *APIInfoError) Error() string { - return fmt.Sprintf("failed api /info call with error: %v", e.Err) -} - -type APIGetEndpointsError struct { - Err error -} - -func (e *APIGetEndpointsError) Error() string { - return fmt.Sprintf("failed to get server endpoint with error %v", e.Err) -} diff --git a/internal/server/server.go b/internal/server/server.go index a1fb749..ce72400 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,9 +2,11 @@ package server import ( "fmt" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/log" "github.com/jwijenbergh/eduvpn-common/internal/oauth" + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" "github.com/jwijenbergh/eduvpn-common/internal/wireguard" ) @@ -17,8 +19,8 @@ type ServerBase struct { ProfilesRaw string `json:"profiles_raw"` StartTime int64 `json:"start-time"` EndTime int64 `json:"end-time"` - Logger *log.FileLogger `json:"-"` - FSM *fsm.FSM `json:"-"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` } // An instute access server @@ -49,18 +51,19 @@ type InstituteServers struct { } func (servers *Servers) GetCurrentServer() (Server, error) { + errorMessage := "failed getting current server" if servers.IsSecureInternet { return &servers.SecureInternetHomeServer, nil } currentInstitute := servers.InstituteServers.CurrentURL institutes := servers.InstituteServers.Map if institutes == nil { - return nil, &ServerGetCurrentNoMapError{} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentNoMapError{}} } institute, exists := institutes[currentInstitute] if !exists || institute == nil { - return nil, &ServerGetCurrentNotFoundError{} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentNotFoundError{}} } return institute, nil } @@ -96,25 +99,27 @@ func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { } func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { + errorMessage := "failed getting current secure internet home base" if server.BaseMap == nil { - return nil, &ServerSecureInternetMapNotFoundError{} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}} } base, exists := server.BaseMap[server.CurrentURL] if !exists { - return nil, &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}} } return base, nil } func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := fmt.Sprintf("failed initializing institute server %s", url) institute.Base.URL = url institute.Base.FSM = fsm institute.Base.Logger = logger endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { - return &ServerInitializeError{URL: url, Err: endpointsErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm, logger) institute.Base.Endpoints = *endpoints @@ -122,6 +127,7 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l } func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url) // Initialize the base map if it is non-nil if secure.BaseMap == nil { secure.BaseMap = make(map[string]*ServerBase) @@ -136,7 +142,7 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l base.URL = url endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { - return &ServerInitializeError{URL: url, Err: endpointsErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } base.Endpoints = *endpoints } @@ -166,7 +172,7 @@ func ShouldRenewButton(server Server) (bool, error) { base, baseErr := server.GetBase() if baseErr != nil { - //return false, &GetRenewButtonTimeError{Err: baseErr} + // return false, &GetRenewButtonTimeError{Err: baseErr} return false, nil } @@ -186,7 +192,7 @@ func ShouldRenewButton(server Server) (bool, error) { // Session duration is less than 24 hours but not 75% has passed duration := base.EndTime - base.StartTime // TODO: Is converting to float64 okay here? - if duration < 24*60*60 && float64(current) <= (float64(base.StartTime) + 0.75*float64(duration)) { + if duration < 24*60*60 && float64(current) <= (float64(base.StartTime)+0.75*float64(duration)) { return false, nil } @@ -198,17 +204,18 @@ func Login(server Server) error { } func EnsureTokens(server Server) error { + errorMessage := "failed ensuring server tokens" base, baseErr := server.GetBase() if baseErr != nil { - return &ServerEnsureTokensError{Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if server.GetOAuth().NeedsRelogin() { base.Logger.Log(log.LOG_INFO, "OAuth: Tokens are invalid, relogging in") loginErr := Login(server) if loginErr != nil { - return &ServerEnsureTokensError{Err: loginErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} } } return nil @@ -223,13 +230,14 @@ func CancelOAuth(server Server) { } func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { + errorMessage := "failed ensuring server" // Intialize the secure internet server // This calls the init method which takes care of the rest if isSecureInternet { initErr := servers.SecureInternetHomeServer.init(url, fsm, logger) if initErr != nil { - return nil, &ServerEnsureServerError{Err: initErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} } servers.IsSecureInternet = true @@ -253,7 +261,7 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm instituteServers.CurrentURL = url instituteInitErr := institute.init(url, fsm, logger) if instituteInitErr != nil { - return nil, &ServerEnsureServerError{Err: instituteInitErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr} } instituteServers.Map[url] = institute servers.IsSecureInternet = false @@ -310,10 +318,11 @@ func (profile *ServerProfile) supportsOpenVPN() bool { } func getCurrentProfile(server Server) (*ServerProfile, error) { + errorMessage := "failed getting current profile" base, baseErr := server.GetBase() if baseErr != nil { - return nil, &ServerGetCurrentProfileError{Err: baseErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profileID := base.Profiles.Current for _, profile := range base.Profiles.Info.ProfileList { @@ -321,28 +330,30 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { return &profile, nil } } - return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} + + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}} } func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, error) { + errorMessage := "failed getting server WireGuard configuration" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", baseErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profile_id := base.Profiles.Current wireguardKey, wireguardErr := wireguard.GenerateKey() if wireguardErr != nil { - return "", "", wireguardErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: wireguardErr} } wireguardPublicKey := wireguardKey.PublicKey().String() config, content, expires, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN) if configErr != nil { - return "", "", wireguardErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } // Store start and end time @@ -361,10 +372,11 @@ func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, er } func openVPNGetConfig(server Server) (string, string, error) { + errorMessage := "failed getting server OpenVPN configuration" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", baseErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profile_id := base.Profiles.Current configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id) @@ -374,25 +386,26 @@ func openVPNGetConfig(server Server) (string, string, error) { base.EndTime = expires if configErr != nil { - return "", "", configErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return configOpenVPN, "openvpn", nil } func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &ServerGetConfigWithProfileError{Err: baseErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if !base.FSM.HasTransition(fsm.HAS_CONFIG) { - return "", "", &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError()} } profile, profileErr := getCurrentProfile(server) if profileErr != nil { - return "", "", &ServerGetConfigWithProfileError{Err: profileErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } supportsOpenVPN := profile.supportsOpenVPN() @@ -400,7 +413,7 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) // If forceTCP we must be able to get a config with OpenVPN if forceTCP && supportsOpenVPN { - return "", "", &ServerGetConfigForceTCPError{} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetConfigForceTCPError{}} } var config string @@ -416,40 +429,42 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) } if configErr != nil { - return "", "", &ServerGetConfigWithProfileError{Err: configErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return config, configType, nil } func askForProfileID(server Server) error { + errorMessage := "failed asking for a server profile ID" base, baseErr := server.GetBase() if baseErr != nil { - return &ServerAskForProfileIDError{Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if !base.FSM.HasTransition(fsm.ASK_PROFILE) { - return &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE}.CustomError()} } base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, base.ProfilesRaw, false) return nil } func GetConfig(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &ServerGetConfigError{Err: baseErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if !base.FSM.InState(fsm.REQUEST_CONFIG) { - return "", "", &fsm.FSMWrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG}.CustomError()} } // Get new profiles using the info call // This does not override the current profile infoErr := APIInfo(server) if infoErr != nil { - return "", "", &ServerGetConfigError{Err: infoErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} } // If there was a profile chosen and it doesn't exist anymore, reset it @@ -473,7 +488,7 @@ func GetConfig(server Server, forceTCP bool) (string, string, error) { profileErr := askForProfileID(server) if profileErr != nil { - return "", "", &ServerGetConfigError{Err: profileErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } return getConfigWithProfile(server, forceTCP) @@ -487,14 +502,6 @@ func (e *ServerGetCurrentProfileNotFoundError) Error() string { return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) } -type ServerGetConfigWithProfileError struct { - Err error -} - -func (e *ServerGetConfigWithProfileError) Error() string { - return fmt.Sprintf("failed to get config including profile with error %v", e.Err) -} - type ServerGetConfigForceTCPError struct{} func (e *ServerGetConfigForceTCPError) Error() string { @@ -507,28 +514,12 @@ func (e *ServerGetSecureInternetHomeError) Error() string { return "failed to get secure internet home server, not found" } -type ServerCopySecureInternetOAuthError struct { - Err error -} - -func (e *ServerCopySecureInternetOAuthError) Error() string { - return fmt.Sprintf("failed to copy oauth tokens from home server with error %v", e.Err) -} - type ServerEnsureServerEmptyURLError struct{} func (e *ServerEnsureServerEmptyURLError) Error() string { return "failed ensuring server, empty url provided" } -type ServerEnsureServerError struct { - Err error -} - -func (e *ServerEnsureServerError) Error() string { - return fmt.Sprintf("failed ensuring server with error %v", e.Err) -} - type ServerGetCurrentNoMapError struct{} func (e *ServerGetCurrentNoMapError) Error() string { @@ -541,31 +532,6 @@ func (e *ServerGetCurrentNotFoundError) Error() string { return "failed getting current server, not found" } -type ServerGetConfigError struct { - Err error -} - -func (e *ServerGetConfigError) Error() string { - return fmt.Sprintf("failed getting server config with error %v", e.Err) -} - -type ServerInitializeError struct { - URL string - Err error -} - -func (e *ServerInitializeError) Error() string { - return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err) -} - -type ServerInstituteBaseNotFoundError struct { - Err error -} - -func (e *ServerInstituteBaseNotFoundError) Error() string { - return "institute base not found" -} - type ServerSecureInternetMapNotFoundError struct{} func (e *ServerSecureInternetMapNotFoundError) Error() string { @@ -579,27 +545,3 @@ type ServerSecureInternetBaseNotFoundError struct { func (e *ServerSecureInternetBaseNotFoundError) Error() string { return fmt.Sprintf("secure internet base not found with current: %s", e.Current) } - -type ServerGetCurrentProfileError struct { - Err error -} - -func (e *ServerGetCurrentProfileError) Error() string { - return fmt.Sprintf("failed getting current profile with error: %v", e.Err) -} - -type ServerAskForProfileIDError struct { - Err error -} - -func (e *ServerAskForProfileIDError) Error() string { - return fmt.Sprintf("ask for profile ID error: %v", e.Err) -} - -type ServerEnsureTokensError struct { - Err error -} - -func (e *ServerEnsureTokensError) Error() string { - return fmt.Sprintf("failed ensuring tokens with error: %v", e.Err) -} -- cgit v1.2.3