diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-02 14:34:35 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-02 14:34:35 +0200 |
| commit | 466450f0c47bdc614e66326d90e5fc6fb56ae732 (patch) | |
| tree | a01518a58d50d2f8449d37dadecc40e35c9f1fe1 | |
| parent | a2a8efdcaad3d9b1852b1367a7cd7e8c5860cecf (diff) | |
Refactor: Wrap most errors in a custom type
| -rw-r--r-- | internal/api.go | 53 | ||||
| -rw-r--r-- | internal/config.go | 22 | ||||
| -rw-r--r-- | internal/fsm.go | 18 | ||||
| -rw-r--r-- | internal/http.go | 109 | ||||
| -rw-r--r-- | internal/log.go | 14 | ||||
| -rw-r--r-- | internal/oauth.go | 113 | ||||
| -rw-r--r-- | internal/openvpn.go | 12 | ||||
| -rw-r--r-- | internal/server.go | 115 | ||||
| -rw-r--r-- | internal/verify.go | 154 | ||||
| -rw-r--r-- | internal/wireguard.go | 28 | ||||
| -rw-r--r-- | state.go | 76 | ||||
| -rw-r--r-- | state_test.go | 36 |
12 files changed, 531 insertions, 219 deletions
diff --git a/internal/api.go b/internal/api.go index 2ed605e..da17f76 100644 --- a/internal/api.go +++ b/internal/api.go @@ -9,6 +9,7 @@ import ( ) // Authorized wrappers on top of HTTP +// the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth func (server *Server) apiAuthorized(method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { // Ensure optional is not nil as we will fill it with headers if opts == nil { @@ -38,28 +39,32 @@ func (server *Server) apiAuthorizedRetry(method string, endpoint string, opts *H if bodyErr != nil { var error *HTTPStatusError - // Only retry authroized if we get a HTTP 401 + // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { server.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error)) // Tell the method that the token is expired server.OAuth.Token.ExpiredTimestamp = GenerateTimeSeconds() - return server.apiAuthorized(method, endpoint, opts) + retryHeader, retryBody, retryErr := server.apiAuthorized(method, endpoint, opts) + if retryErr != nil { + return nil, nil, &APIAuthorizedError{Err: retryErr} + } + return retryHeader, retryBody, nil } - return header, nil, bodyErr + return nil, nil, &APIAuthorizedError{Err: bodyErr} } - return header, body, bodyErr + return header, body, nil } func (server *Server) APIInfo() error { _, body, bodyErr := server.apiAuthorizedRetry(http.MethodGet, "/info", nil) if bodyErr != nil { - return bodyErr + return &APIInfoError{Err: bodyErr} } structure := ServerProfileInfo{} jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { - return jsonErr + return &APIInfoError{Err: jsonErr} } server.Profiles = structure @@ -79,7 +84,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str } header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", "", connectErr + return "", "", &APIConnectWireguardError{Err: connectErr} } expires := header.Get("expires") @@ -97,7 +102,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro } header, connectBody, connectErr := server.apiAuthorizedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", "", connectErr + return "", "", &APIConnectOpenVPNError{Err: connectErr} } expires := header.Get("expires") @@ -108,3 +113,35 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro func (server *Server) APIDisconnect() { server.apiAuthorizedRetry(http.MethodPost, "/disconnect", nil) } + +type APIAuthorizedError struct { + Err error +} + +func (e *APIAuthorizedError) Error() string { + return fmt.Sprintf("failed api authorized call with error: %v", e.Err) +} + +type APIConnectWireguardError struct { + Err error +} + +func (e *APIConnectWireguardError) Error() string { + return fmt.Sprintf("failed api /connect wireguard call with error: %v", e.Err) +} + +type APIConnectOpenVPNError struct { + Err error +} + +func (e *APIConnectOpenVPNError) Error() string { + return fmt.Sprintf("failed api /connect OpenVPN call with error: %v", e.Err) +} + +type APIInfoError struct { + Err error +} + +func (e *APIInfoError) Error() string { + return fmt.Sprintf("failed api /info call with error: %v", e.Err) +} diff --git a/internal/config.go b/internal/config.go index 47f773e..a135ac6 100644 --- a/internal/config.go +++ b/internal/config.go @@ -25,11 +25,11 @@ func (config *Config) GetFilename() string { func (config *Config) Save(readStruct interface{}) error { configDirErr := EnsureDirectory(config.Directory) if configDirErr != nil { - return configDirErr + return &ConfigSaveError{Err: configDirErr} } jsonString, marshalErr := json.Marshal(readStruct) if marshalErr != nil { - return marshalErr + return &ConfigSaveError{Err: marshalErr} } return ioutil.WriteFile(config.GetFilename(), jsonString, 0o644) } @@ -37,7 +37,23 @@ func (config *Config) Save(readStruct interface{}) error { func (config *Config) Load(writeStruct interface{}) error { bytes, readErr := ioutil.ReadFile(config.GetFilename()) if readErr != nil { - return readErr + return &ConfigLoadError{Err: readErr} } return json.Unmarshal(bytes, writeStruct) } + +type ConfigSaveError struct { + Err error +} + +func (e *ConfigSaveError) Error() string { + return fmt.Sprintf("failed to save config with error: %v", e.Err) +} + +type ConfigLoadError struct { + Err error +} + +func (e *ConfigLoadError) Error() string { + return fmt.Sprintf("failed to load config with error: %v", e.Err) +} diff --git a/internal/fsm.go b/internal/fsm.go index 1bcc479..0b9ad1e 100644 --- a/internal/fsm.go +++ b/internal/fsm.go @@ -206,3 +206,21 @@ func (fsm *FSM) generateMermaidGraph() string { func (fsm *FSM) GenerateGraph() string { return fsm.generateMermaidGraph() } + +type FSMWrongStateTransitionError struct { + Got FSMStateID + Want FSMStateID +} + +func (e *FSMWrongStateTransitionError) Error() string { + return fmt.Sprintf("wrong FSM state, got: %s, want a state with a transition to: %s", e.Got.String(), e.Want.String()) +} + +type FSMWrongStateError struct { + Got FSMStateID + Want FSMStateID +} + +func (e *FSMWrongStateError) Error() string { + return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String()) +} diff --git a/internal/http.go b/internal/http.go index 8ca8cb9..0b1eda4 100644 --- a/internal/http.go +++ b/internal/http.go @@ -9,52 +9,6 @@ import ( "strings" ) -type HTTPResourceError struct { - URL string - Err error -} - -func (e *HTTPResourceError) Error() string { - return fmt.Sprintf("failed obtaining HTTP resource %s with error %v", e.URL, e.Err) -} - -type HTTPStatusError struct { - URL string - Status int -} - -func (e *HTTPStatusError) Error() string { - return fmt.Sprintf("failed obtaining HTTP resource %s as it gave an unsuccesful status code %d", e.URL, e.Status) -} - -type HTTPReadError struct { - URL string - Err error -} - -func (e *HTTPReadError) Error() string { - return fmt.Sprintf("failed reading HTTP resource %s with error %v", e.URL, e.Err) -} - -type HTTPParseJsonError struct { - URL string - Body string - Err error -} - -func (e *HTTPParseJsonError) Error() string { - return fmt.Sprintf("failed parsing json %s for HTTP resource %s with error %v", e.Body, e.URL, e.Err) -} - -type HTTPRequestCreateError struct { - URL string - Err error -} - -func (e *HTTPRequestCreateError) Error() string { - return fmt.Sprintf("failed to create HTTP request with url %s and error %v", e.URL, e.Err) -} - type URLParameters map[string]string type HTTPOptionalParams struct { @@ -65,9 +19,9 @@ type HTTPOptionalParams struct { // Construct an URL including on parameters func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) { - url, err := url.Parse(baseURL) - if err != nil { - return "", err + url, parseErr := url.Parse(baseURL) + if parseErr != nil { + return "", &HTTPConstructURLError{URL: baseURL, Parameters: parameters, Err: parseErr} } q := url.Query() @@ -130,6 +84,7 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht // it already has the right error so so we don't wrap it further url, urlErr := httpOptionalURL(url, opts) if urlErr != nil { + // No further type wrapping is needed here return nil, nil, urlErr } @@ -170,3 +125,59 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht // Return the body in bytes and signal the status error if there was one return resp.Header, body, nil } + +type HTTPResourceError struct { + URL string + Err error +} + +func (e *HTTPResourceError) Error() string { + return fmt.Sprintf("failed obtaining HTTP resource: %s with error: %v", e.URL, e.Err) +} + +type HTTPStatusError struct { + URL string + Status int +} + +func (e *HTTPStatusError) Error() string { + return fmt.Sprintf("failed obtaining HTTP resource: %s as it gave an unsuccesful status code: %d", e.URL, e.Status) +} + +type HTTPReadError struct { + URL string + Err error +} + +func (e *HTTPReadError) Error() string { + return fmt.Sprintf("failed reading HTTP resource: %s with error: %v", e.URL, e.Err) +} + +type HTTPParseJsonError struct { + URL string + Body string + Err error +} + +func (e *HTTPParseJsonError) Error() string { + return fmt.Sprintf("failed parsing json %s for HTTP resource: %s with error: %v", e.Body, e.URL, e.Err) +} + +type HTTPRequestCreateError struct { + URL string + Err error +} + +func (e *HTTPRequestCreateError) Error() string { + return fmt.Sprintf("failed to create HTTP request with url: %s and error: %v", e.URL, e.Err) +} + +type HTTPConstructURLError struct { + URL string + Parameters URLParameters + Err error +} + +func (e *HTTPConstructURLError) Error() string { + return fmt.Sprintf("failed to construct url: %s including parameters: %v with error: %v", e.URL, e.Parameters, e.Err) +} diff --git a/internal/log.go b/internal/log.go index 9248fc0..5109ba2 100644 --- a/internal/log.go +++ b/internal/log.go @@ -39,11 +39,11 @@ func (e LogLevel) String() string { func (logger *FileLogger) Init(level LogLevel, name string, directory string) error { configDirErr := EnsureDirectory(directory) if configDirErr != nil { - return configDirErr + return &LogInitializeError{Name: name, Directory: directory, Err: configDirErr} } logFile, logOpenErr := os.OpenFile(logger.getFilename(directory, name), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666) if logOpenErr != nil { - return logOpenErr + return &LogInitializeError{Name: name, Directory: directory, Err: logOpenErr} } log.SetOutput(logFile) logger.File = logFile @@ -65,3 +65,13 @@ func (logger *FileLogger) Log(level LogLevel, str string) { func (logger *FileLogger) Close() { logger.File.Close() } + +type LogInitializeError struct { + Name string + Directory string + Err error +} + +func (e *LogInitializeError) Error() string { + return fmt.Sprintf("failed initializing logging with name: %s and directory: %s with error: %v", e.Name, e.Directory, e.Err) +} diff --git a/internal/oauth.go b/internal/oauth.go index 98af5a4..c13ea99 100644 --- a/internal/oauth.go +++ b/internal/oauth.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -20,7 +19,7 @@ import ( func genState() (string, error) { randomBytes, err := MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenStateUnableError{Err: err} + return "", &OAuthGenStateError{Err: err} } // For consistency we also use raw url encoding here @@ -46,7 +45,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenVerifierUnableError{Err: err} + return "", &OAuthGenVerifierError{Err: err} } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -95,7 +94,7 @@ func (oauth *OAuth) getTokensWithCallback() error { } mux.HandleFunc("/callback", oauth.Callback) if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { - return &OAuthFailedCallbackError{Addr: addr, Err: err} + return &OAuthCallbackError{Addr: addr, Err: err} } return oauth.Session.CallbackError } @@ -122,7 +121,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { current_time := GenerateTimeSeconds() _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return bodyErr + return &OAuthAuthError{Err: bodyErr} } tokenStructure := OAuthToken{} @@ -160,7 +159,7 @@ func (oauth *OAuth) getTokensWithRefresh() error { current_time := GenerateTimeSeconds() _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return bodyErr + return &OAuthRefreshError{Err: bodyErr} } tokenStructure := OAuthToken{} @@ -181,7 +180,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Extract the authorization code code, success := req.URL.Query()["code"] if !success { - oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()} + oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } @@ -192,14 +191,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state, success := req.URL.Query()["state"] if !success { - oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()} + oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } // The state is the first entry extractedState := state[0] if extractedState != oauth.Session.State { - oauth.Session.CallbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} + oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } @@ -208,7 +207,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Obtaining the access and refresh tokens err := oauth.getTokensWithAuthCode(extractedCode) if err != nil { - oauth.Session.CallbackError = &OAuthFailedCallbackGetTokensError{Err: err} + oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } @@ -225,18 +224,18 @@ func (oauth *OAuth) Init(fsm *FSM, logger *FileLogger) { // Starts the OAuth exchange for eduvpn. func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) error { if !oauth.FSM.HasTransition(OAUTH_STARTED) { - return errors.New(fmt.Sprintf("Failed starting oauth, invalid state %s", oauth.FSM.Current.String())) + return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED} } // Generate the state state, stateErr := genState() if stateErr != nil { - return &OAuthFailedInitializeError{Err: stateErr} + return &OAuthInitializeError{Err: stateErr} } // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return &OAuthFailedInitializeError{Err: verifierErr} + return &OAuthInitializeError{Err: verifierErr} } challenge := genChallengeS256(verifier) @@ -252,8 +251,8 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) authURL, urlErr := HTTPConstructURL(authorizationURL, parameters) - if urlErr != nil { // shouldn't happen - panic(urlErr) + if urlErr != nil { + return &OAuthInitializeError{Err: urlErr} } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client @@ -268,12 +267,12 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) // Error definitions func (oauth *OAuth) Finish() error { if !oauth.FSM.HasTransition(AUTHORIZED) { - return errors.New("invalid state to finish oauth") + return &FSMWrongStateError{Got: oauth.FSM.Current, Want: AUTHORIZED} } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return tokenErr + return &OAuthFinishError{Err: tokenErr} } oauth.FSM.GoTransition(AUTHORIZED) return nil @@ -288,13 +287,13 @@ func (oauth *OAuth) Login(name string, authorizationURL string, tokenURL string) authInitializeErr := oauth.start(name, authorizationURL, tokenURL) if authInitializeErr != nil { - return authInitializeErr + return &OAuthLoginError{Err: authInitializeErr} } oauthErr := oauth.Finish() if oauthErr != nil { - return oauthErr + return &OAuthLoginError{Err: oauthErr} } return nil } @@ -329,64 +328,96 @@ func (oauth *OAuth) NeedsRelogin() bool { type OAuthCancelledCallbackError struct{} func (e *OAuthCancelledCallbackError) Error() string { - return fmt.Sprintf("Client cancelled OAuth") + return fmt.Sprintf("client cancelled OAuth") } -type OAuthGenStateUnableError struct { +type OAuthGenStateError struct { Err error } -func (e *OAuthGenStateUnableError) Error() string { - return fmt.Sprintf("failed generating state with error %v", e.Err) +func (e *OAuthGenStateError) Error() string { + return fmt.Sprintf("failed generating state with error: %v", e.Err) } -type OAuthGenVerifierUnableError struct { +type OAuthGenVerifierError struct { Err error } -func (e *OAuthGenVerifierUnableError) Error() string { - return fmt.Sprintf("failed generating verifier with error %v", e.Err) +func (e *OAuthGenVerifierError) Error() string { + return fmt.Sprintf("failed generating verifier with error: %v", e.Err) } -type OAuthFailedCallbackError struct { +type OAuthCallbackError struct { Addr string Err error } -func (e *OAuthFailedCallbackError) Error() string { - return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err) +func (e *OAuthCallbackError) Error() string { + return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err) } -type OAuthFailedCallbackParameterError struct { +type OAuthCallbackParameterError struct { Parameter string URL string } -func (e *OAuthFailedCallbackParameterError) Error() string { - return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL) +func (e *OAuthCallbackParameterError) Error() string { + return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL) } -type OAuthFailedCallbackStateMatchError struct { +type OAuthCallbackStateMatchError struct { State string ExpectedState string } -func (e *OAuthFailedCallbackStateMatchError) Error() string { - return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState) +func (e *OAuthCallbackStateMatchError) Error() string { + return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } -type OAuthFailedCallbackGetTokensError struct { +type OAuthCallbackGetTokensError struct { Err error } -func (e *OAuthFailedCallbackGetTokensError) Error() string { - return fmt.Sprintf("failed getting tokens with error %v", e.Err) +func (e *OAuthCallbackGetTokensError) Error() string { + return fmt.Sprintf("failed getting tokens with error: %v", e.Err) } -type OAuthFailedInitializeError struct { +type OAuthFinishError struct { Err error } -func (e *OAuthFailedInitializeError) Error() string { - return fmt.Sprintf("failed initializing OAuth with error %v", e.Err) +func (e *OAuthFinishError) Error() string { + return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err) +} + +type OAuthLoginError struct { + Err error +} + +func (e *OAuthLoginError) Error() string { + return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err) +} + +type OAuthInitializeError struct { + Err error +} + +func (e *OAuthInitializeError) Error() string { + return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err) +} + +type OAuthAuthError struct { + Err error +} + +func (e *OAuthAuthError) Error() string { + return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err) +} + +type OAuthRefreshError struct { + Err error +} + +func (e *OAuthRefreshError) Error() string { + return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err) } diff --git a/internal/openvpn.go b/internal/openvpn.go index 1b2e626..ed31fe2 100644 --- a/internal/openvpn.go +++ b/internal/openvpn.go @@ -1,12 +1,22 @@ package internal +import "fmt" + func (server *Server) OpenVPNGetConfig() (string, error) { profile_id := server.Profiles.Current configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id) if configErr != nil { - return "", configErr + return "", &OpenVPNGetConfigError{Err: configErr} } return configOpenVPN, nil } + +type OpenVPNGetConfigError struct { + Err error +} + +func (e *OpenVPNGetConfigError) Error() string { + return fmt.Sprintf("failed getting OpenVPN config with error: %v", e.Err) +} diff --git a/internal/server.go b/internal/server.go index 1d6f1e1..489719e 100644 --- a/internal/server.go +++ b/internal/server.go @@ -24,12 +24,12 @@ type Servers struct { func (servers *Servers) GetCurrentServer() (*Server, error) { if servers.List == nil { - return nil, errors.New("No map found to get Current Server") + return nil, &ServerGetCurrentNoMapError{} } server, exists := servers.List[servers.Current] if !exists || server == nil { - return nil, errors.New("Current Server not found") + return nil, &ServerGetCurrentNotFoundError{} } return server, nil } @@ -45,7 +45,7 @@ func (server *Server) Init(url string, fsm *FSM, logger *FileLogger) error { server.OAuth.Init(fsm, logger) endpointsErr := server.GetEndpoints() if endpointsErr != nil { - return endpointsErr + return &ServerInitializeError{URL: url, Err: endpointsErr} } return nil } @@ -60,7 +60,7 @@ func (server *Server) EnsureTokens() error { func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, makeCurrent bool) (*Server, error) { if url == "" { - return nil, errors.New("Emtpy URL to ensure Server") + return nil, &ServerEnsureServerEmptyURLError{} } if servers.List == nil { servers.List = make(map[string]*Server) @@ -74,7 +74,7 @@ func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, m serverInitErr := server.Init(url, fsm, logger) if serverInitErr != nil { - return nil, serverInitErr + return nil, &ServerEnsureServerError{Err: serverInitErr} } servers.List[url] = server @@ -88,7 +88,7 @@ func (servers *Servers) getSecureInternetHome() (*Server, error) { server, exists := servers.List[servers.SecureHome] if !exists || server == nil { - return nil, errors.New("No secure internet home found") + return nil, &ServerGetSecureInternetHomeError{} } return server, nil @@ -104,7 +104,7 @@ func (servers *Servers) CopySecureInternetOAuth(server *Server) error { secureHome, secureHomeErr := servers.getSecureInternetHome() if secureHomeErr != nil { - return secureHomeErr + return &ServerCopySecureInternetOAuthError{Err: secureHomeErr} } // Forward token properties @@ -155,14 +155,14 @@ func (server *Server) GetEndpoints() error { _, body, bodyErr := HTTPGet(url) if bodyErr != nil { - return bodyErr + return &ServerGetEndpointsError{Err: bodyErr} } endpoints := ServerEndpoints{} jsonErr := json.Unmarshal(body, &endpoints) if jsonErr != nil { - return jsonErr + return &ServerGetEndpointsError{Err: jsonErr} } server.Endpoints = endpoints @@ -180,23 +180,23 @@ func (profile *ServerProfile) supportsWireguard() bool { } func (server *Server) getCurrentProfile() (*ServerProfile, error) { - profile_id := server.Profiles.Current + profileID := server.Profiles.Current for _, profile := range server.Profiles.Info.ProfileList { - if profile.ID == profile_id { + if profile.ID == profileID { return &profile, nil } } - return nil, errors.New(fmt.Sprintf("no profile found for id %s", profile_id)) + return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} } func (server *Server) getConfigWithProfile() (string, error) { if !server.FSM.HasTransition(HAS_CONFIG) { - return "", errors.New("cannot get a config with a profile, invalid state") + return "", &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: HAS_CONFIG} } profile, profileErr := server.getCurrentProfile() if profileErr != nil { - return "", profileErr + return "", &ServerGetConfigWithProfileError{Err: profileErr} } if profile.supportsWireguard() { @@ -207,7 +207,7 @@ func (server *Server) getConfigWithProfile() (string, error) { func (server *Server) askForProfileID() error { if !server.FSM.HasTransition(ASK_PROFILE) { - return errors.New("cannot ask for a profile id, invalid state") + return &FSMWrongStateTransitionError{Got: server.FSM.Current, Want: ASK_PROFILE} } server.FSM.GoTransitionWithData(ASK_PROFILE, server.ProfilesRaw, false) return nil @@ -220,7 +220,7 @@ func (server *Server) GetConfig() (string, error) { infoErr := server.APIInfo() if infoErr != nil { - return "", infoErr + return "", &ServerGetConfigError{Err: infoErr} } // Set the current profile if there is only one profile @@ -232,8 +232,89 @@ func (server *Server) GetConfig() (string, error) { profileErr := server.askForProfileID() if profileErr != nil { - return "", nil + return "", &ServerGetConfigError{Err: profileErr} } return server.getConfigWithProfile() } + +type ServerGetCurrentProfileNotFoundError struct { + ProfileID string +} + +func (e *ServerGetCurrentProfileNotFoundError) Error() string { + return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) +} + +type ServerGetConfigWithProfileError struct { + Err error +} + +func (e *ServerGetConfigWithProfileError) Error() string { + return fmt.Sprintf("failed to get config including profile with error %v", e.Err) +} + +type ServerGetEndpointsError struct { + Err error +} + +func (e *ServerGetEndpointsError) Error() string { + return fmt.Sprintf("failed to get server endpoint with error %v", e.Err) +} + +type ServerGetSecureInternetHomeError struct{} + +func (e *ServerGetSecureInternetHomeError) Error() string { + return "failed to get secure internet home server, not found" +} + +type ServerCopySecureInternetOAuthError struct { + Err error +} + +func (e *ServerCopySecureInternetOAuthError) Error() string { + return fmt.Sprintf("failed to copy oauth tokens from home server with error %v", e.Err) +} + +type ServerEnsureServerEmptyURLError struct{} + +func (e *ServerEnsureServerEmptyURLError) Error() string { + return "failed ensuring server, empty url provided" +} + +type ServerEnsureServerError struct { + Err error +} + +func (e *ServerEnsureServerError) Error() string { + return fmt.Sprintf("failed ensuring server with error %v", e.Err) +} + +type ServerGetCurrentNoMapError struct{} + +func (e *ServerGetCurrentNoMapError) Error() string { + return "failed getting current server, no servers available" +} + +type ServerGetCurrentNotFoundError struct{} + +func (e *ServerGetCurrentNotFoundError) Error() string { + return "failed getting current server, not found" +} + +type ServerGetConfigError struct { + Err error +} + +func (e *ServerGetConfigError) Error() string { + return fmt.Sprintf("failed getting server config with error %v", e.Err) +} + +type ServerInitializeError struct { + URL string + Err error +} + +func (e *ServerInitializeError) Error() string { + return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err) +} diff --git a/internal/verify.go b/internal/verify.go index 9128777..713e4d7 100644 --- a/internal/verify.go +++ b/internal/verify.go @@ -58,13 +58,81 @@ func InsecureTestingSetExtraKey(keyString string) { extraKey = keyString } +// verifyWithKeys verifies the Minisign signature in signatureFileContent (minisig file format) over the server_list/organization_list JSON in signedJson. +// +// Verification is performed using a matching key in allowedPublicKeys. +// The signature is checked to be a Ed25519 Minisign (optionally Ed25519 Blake2b-512 prehashed, see forcePrehash) signature with a valid trusted comment. +// The file type that is verified is indicated by expectedFileName, which must be one of "server_list.json"/"organization_list.json". +// The trusted comment is checked to be of the form "timestamp:<timestamp>\tfile:<expectedFileName>", optionally suffixed by something, e.g. "\thashed". +// The signature is checked to have a timestamp with a value of at least minSignTime, which is a UNIX timestamp without milliseconds. +// +// The return value will either be (true, nil) on success or (false, detailedVerifyError) on failure. +func verifyWithKeys(signatureFileContent string, signedJson []byte, filename string, minSignTime uint64, allowedPublicKeys []string, forcePrehash bool) (bool, error) { + switch filename { + case "server_list.json", "organization_list.json": + break + default: + return false, &VerifyUnknownExpectedFilenameError{Filename: filename, Expected: "server_list.json or organization_list.json"} + } + + sig, err := minisign.DecodeSignature(signatureFileContent) + if err != nil { + return false, &VerifyInvalidSignatureFormatError{Err: err} + } + + // Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format + if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} { + return false, &VerifyInvalidSignatureAlgorithmError{Algorithm: string(sig.SignatureAlgorithm[:]), WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)"} + } + + // Find allowed key used for signature + for _, keyStr := range allowedPublicKeys { + key, err := minisign.NewPublicKey(keyStr) + if err != nil { + // Should only happen if Verify is wrong or extraKey is invalid + return false, &VerifyCreatePublicKeyError{PublicKey: keyStr, Err: err} + } + + if sig.KeyId != key.KeyId { + continue // Wrong key + } + + valid, err := key.Verify(signedJson, sig) + if !valid { + return false, &VerifyInvalidSignatureError{Err: err} + } + + // Parse trusted comment + var signTime uint64 + var sigFileName string + // sigFileName cannot have spaces + _, err = fmt.Sscanf(sig.TrustedComment, "trusted comment: timestamp:%d\tfile:%s", &signTime, &sigFileName) + if err != nil { + return false, &VerifyInvalidTrustedCommentError{TrustedComment: sig.TrustedComment, Err: err} + } + + if sigFileName != filename { + return false, &VerifyWrongSigFilenameError{Filename: filename, SigFilename: sigFileName} + } + + if signTime < minSignTime { + return false, &VerifySigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime} + } + + return true, nil + } + + // No matching allowed key found + return false, &VerifyUnknownKeyError{Filename: filename} +} + type VerifyUnknownExpectedFilenameError struct { Filename string Expected string } func (e *VerifyUnknownExpectedFilenameError) Error() string { - return fmt.Sprintf("invalid filename %s, expected %s", e.Filename, e.Expected) + return fmt.Sprintf("invalid filename: %s, expected: %s", e.Filename, e.Expected) } type VerifyInvalidSignatureFormatError struct { @@ -72,7 +140,7 @@ type VerifyInvalidSignatureFormatError struct { } func (e *VerifyInvalidSignatureFormatError) Error() string { - return fmt.Sprintf("invalid signature format, error %v", e.Err) + return fmt.Sprintf("invalid signature format with error: %v", e.Err) } type VerifyInvalidSignatureAlgorithmError struct { @@ -81,7 +149,7 @@ type VerifyInvalidSignatureAlgorithmError struct { } func (e *VerifyInvalidSignatureAlgorithmError) Error() string { - return fmt.Sprintf("invalid signature algorithm %s, wanted %s", e.Algorithm, e.WantedAlgorithm) + return fmt.Sprintf("invalid signature algorithm: %s, wanted: %s", e.Algorithm, e.WantedAlgorithm) } type VerifyCreatePublicKeyError struct { @@ -90,7 +158,7 @@ type VerifyCreatePublicKeyError struct { } func (e *VerifyCreatePublicKeyError) Error() string { - return fmt.Sprintf("failed to create public key %s with error %v", e.PublicKey, e.Err) + return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err) } type VerifyInvalidSignatureError struct { @@ -98,7 +166,7 @@ type VerifyInvalidSignatureError struct { } func (e *VerifyInvalidSignatureError) Error() string { - return fmt.Sprintf("invalid signature with error %v", e.Err) + return fmt.Sprintf("invalid signature with error: %v", e.Err) } type VerifyInvalidTrustedCommentError struct { @@ -107,7 +175,7 @@ type VerifyInvalidTrustedCommentError struct { } func (e *VerifyInvalidTrustedCommentError) Error() string { - return fmt.Sprintf("invalid trusted comment %s with error %v", e.TrustedComment, e.Err) + return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err) } type VerifyWrongSigFilenameError struct { @@ -116,7 +184,7 @@ type VerifyWrongSigFilenameError struct { } func (e *VerifyWrongSigFilenameError) Error() string { - return fmt.Sprintf("wrong filename %s, expected filename %s for signature", e.Filename, e.SigFilename) + return fmt.Sprintf("wrong filename: %s, expected filename: %s for signature", e.Filename, e.SigFilename) } type VerifySigTimeEarlierError struct { @@ -125,7 +193,7 @@ type VerifySigTimeEarlierError struct { } func (e *VerifySigTimeEarlierError) Error() string { - return fmt.Sprintf("Sign time %d is earlier than sign time %d", e.SigTime, e.MinSigTime) + return fmt.Sprintf("Sign time: %d is earlier than sign time: %d", e.SigTime, e.MinSigTime) } type VerifyUnknownKeyError struct { @@ -133,73 +201,5 @@ type VerifyUnknownKeyError struct { } func (e *VerifyUnknownKeyError) Error() string { - return fmt.Sprintf("signature for filename %s was created with an unknown key", e.Filename) -} - -// verifyWithKeys verifies the Minisign signature in signatureFileContent (minisig file format) over the server_list/organization_list JSON in signedJson. -// -// Verification is performed using a matching key in allowedPublicKeys. -// The signature is checked to be a Ed25519 Minisign (optionally Ed25519 Blake2b-512 prehashed, see forcePrehash) signature with a valid trusted comment. -// The file type that is verified is indicated by expectedFileName, which must be one of "server_list.json"/"organization_list.json". -// The trusted comment is checked to be of the form "timestamp:<timestamp>\tfile:<expectedFileName>", optionally suffixed by something, e.g. "\thashed". -// The signature is checked to have a timestamp with a value of at least minSignTime, which is a UNIX timestamp without milliseconds. -// -// The return value will either be (true, nil) on success or (false, detailedVerifyError) on failure. -func verifyWithKeys(signatureFileContent string, signedJson []byte, filename string, minSignTime uint64, allowedPublicKeys []string, forcePrehash bool) (bool, error) { - switch filename { - case "server_list.json", "organization_list.json": - break - default: - return false, &VerifyUnknownExpectedFilenameError{Filename: filename, Expected: "server_list.json or organization_list.json"} - } - - sig, err := minisign.DecodeSignature(signatureFileContent) - if err != nil { - return false, &VerifyInvalidSignatureFormatError{Err: err} - } - - // Check if signature is prehashed, see https://jedisct1.github.io/minisign/#signature-format - if forcePrehash && sig.SignatureAlgorithm != [2]byte{'E', 'D'} { - return false, &VerifyInvalidSignatureAlgorithmError{Algorithm: string(sig.SignatureAlgorithm[:]), WantedAlgorithm: "ED (BLAKE2b-prehashed EdDSA)"} - } - - // Find allowed key used for signature - for _, keyStr := range allowedPublicKeys { - key, err := minisign.NewPublicKey(keyStr) - if err != nil { - // Should only happen if Verify is wrong or extraKey is invalid - return false, &VerifyCreatePublicKeyError{PublicKey: keyStr, Err: err} - } - - if sig.KeyId != key.KeyId { - continue // Wrong key - } - - valid, err := key.Verify(signedJson, sig) - if !valid { - return false, &VerifyInvalidSignatureError{Err: err} - } - - // Parse trusted comment - var signTime uint64 - var sigFileName string - // sigFileName cannot have spaces - _, err = fmt.Sscanf(sig.TrustedComment, "trusted comment: timestamp:%d\tfile:%s", &signTime, &sigFileName) - if err != nil { - return false, &VerifyInvalidTrustedCommentError{TrustedComment: sig.TrustedComment, Err: err} - } - - if sigFileName != filename { - return false, &VerifyWrongSigFilenameError{Filename: filename, SigFilename: sigFileName} - } - - if signTime < minSignTime { - return false, &VerifySigTimeEarlierError{SigTime: signTime, MinSigTime: minSignTime} - } - - return true, nil - } - - // No matching allowed key found - return false, &VerifyUnknownKeyError{Filename: filename} + return fmt.Sprintf("signature for filename: %s was created with an unknown key", e.Filename) } diff --git a/internal/wireguard.go b/internal/wireguard.go index 4ec12bd..7977dbc 100644 --- a/internal/wireguard.go +++ b/internal/wireguard.go @@ -8,8 +8,12 @@ import ( ) func wireguardGenerateKey() (wgtypes.Key, error) { - key, error := wgtypes.GeneratePrivateKey() - return key, error + key, keyErr := wgtypes.GeneratePrivateKey() + + if keyErr != nil { + return key, &WireguardGenerateKeyError{Err: keyErr} + } + return key, nil } // FIXME: Instead of doing a regex replace, decide if we should use a parser @@ -31,14 +35,14 @@ func (server *Server) WireguardGetConfig() (string, error) { wireguardKey, wireguardErr := wireguardGenerateKey() if wireguardErr != nil { - return "", wireguardErr + return "", &WireguardGetConfigError{Err: wireguardErr} } wireguardPublicKey := wireguardKey.PublicKey().String() configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey) if configErr != nil { - return "", configErr + return "", &WireguardGetConfigError{Err: wireguardErr} } // FIXME: Store expiry @@ -50,3 +54,19 @@ func (server *Server) WireguardGetConfig() (string, error) { return configWireguardKey, nil } + +type WireguardGenerateKeyError struct { + Err error +} + +func (e *WireguardGenerateKeyError) Error() string { + return fmt.Sprintf("failed generating Wireguard key with error: %v", e.Err) +} + +type WireguardGetConfigError struct { + Err error +} + +func (e *WireguardGetConfigError) Error() string { + return fmt.Sprintf("failed getting Wireguard config with error: %v", e.Err) +} @@ -1,7 +1,7 @@ package eduvpn import ( - "errors" + "fmt" "github.com/jwijenbergh/eduvpn-common/internal" ) @@ -28,7 +28,7 @@ type VPNState struct { func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { if !state.FSM.InState(internal.DEREGISTERED) { - return errors.New("app already registered") + return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.DEREGISTERED} } // Initialize the logger logLevel := internal.LOG_WARNING @@ -39,7 +39,7 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun loggerErr := state.Logger.Init(logLevel, name, directory) if loggerErr != nil { - return errors.New("Failed to create a logger") + return &StateRegisterError{Err: loggerErr} } // Initialize the FSM @@ -75,13 +75,13 @@ func (state *VPNState) Deregister() error { func (state *VPNState) CancelOAuth() error { if !state.FSM.InState(internal.OAUTH_STARTED) { - return errors.New("cannot cancel oauth, oauth not started") + return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.OAUTH_STARTED} } server, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - return serverErr + return &StateOAuthCancelError{Err: serverErr} } server.CancelOAuth() return nil @@ -89,13 +89,13 @@ func (state *VPNState) CancelOAuth() error { func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) { if state.FSM.InState(internal.DEREGISTERED) { - return "", errors.New("app not registered") + return "", &StateFSMNotRegisteredError{} } // New server chosen, ensure the server is fresh server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger, true) if serverErr != nil { - return "", serverErr + return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr} } // When we connect to secure internet, copy over the tokens from the home server @@ -118,7 +118,7 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st // We are possibly in oauth started // So go to chosen server state.FSM.GoTransition(internal.CHOSEN_SERVER) - return "", loginErr + return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: loginErr} } } else { // OAuth was valid, ensure we are in the authorized state state.FSM.GoTransition(internal.AUTHORIZED) @@ -132,7 +132,7 @@ func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (st config, configErr := server.GetConfig() if configErr != nil { - return "", configErr + return "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr} } else { state.FSM.GoTransition(internal.HAS_CONFIG) } @@ -150,27 +150,77 @@ func (state *VPNState) ConnectSecureInternet(url string) (string, error) { func (state *VPNState) GetDiscoOrganizations() (string, error) { if state.FSM.InState(internal.DEREGISTERED) { - return "", errors.New("app not registered") + return "", &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.DEREGISTERED} } return state.Discovery.GetOrganizationsList() } func (state *VPNState) GetDiscoServers() (string, error) { if state.FSM.InState(internal.DEREGISTERED) { - return "", errors.New("app not registered") + return "", &StateFSMNotRegisteredError{} } return state.Discovery.GetServersList() } func (state *VPNState) SetProfileID(profileID string) error { if !state.FSM.InState(internal.ASK_PROFILE) { - return errors.New("Invalid state for setting a profile") + return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.ASK_PROFILE} } server, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - return errors.New("No server found for setting a profile ID") + return &StateSetProfileError{ProfileID: profileID, Err: serverErr} } server.Profiles.Current = profileID return nil } + +type StateSetProfileError struct { + ProfileID string + Err error +} + +func (e *StateSetProfileError) Error() string { + return fmt.Sprintf("failed to set profile ID %s with error %v", e.ProfileID, e.Err) +} + +type StateRegisterError struct { + Err error +} + +func (e *StateRegisterError) Error() string { + return fmt.Sprintf("failed to register with error %v", e.Err) +} + +type StateFSMNotRegisteredError struct{} + +func (e *StateFSMNotRegisteredError) Error() string { + return fmt.Sprintf("state is not registered. Current FSM state: %s", internal.DEREGISTERED.String()) +} + +type StateWrongFSMStateError struct { + Got internal.FSMStateID + Want internal.FSMStateID +} + +func (e *StateWrongFSMStateError) Error() string { + return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String()) +} + +type StateOAuthCancelError struct { + Err error +} + +func (e *StateOAuthCancelError) Error() string { + return fmt.Sprintf("failed cancelling OAuth for state with error: %v", e.Err) +} + +type StateConnectError struct { + URL string + IsSecureInternet bool + Err error +} + +func (e *StateConnectError) Error() string { + return fmt.Sprintf("failed connecting to server: %s (is secure internet: %v) with error: %v", e.URL, e.IsSecureInternet, e.Err) +} diff --git a/state_test.go b/state_test.go index 4320a6d..c6e33e0 100644 --- a/state_test.go +++ b/state_test.go @@ -84,15 +84,43 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter }, false) _, configErr := state.ConnectInstituteAccess(serverURI) - if !errors.As(configErr, expectedErr) { - t.Errorf("error %T = %v, wantErr %T", configErr, configErr, expectedErr) + var stateErr *StateConnectError + var loginErr *internal.OAuthLoginError + var finishErr *internal.OAuthFinishError + + // We go through the chain of errors by unwrapping them one by one + + // First ensure we get a state connect error + if !errors.As(configErr, &stateErr) { + t.Errorf("error %T = %v, wantErr %T", configErr, configErr, stateErr) + } + + // Then ensure we get a login error + gotLoginErr := stateErr.Err + + if !errors.As(gotLoginErr, &loginErr) { + t.Errorf("error %T = %v, wantErr %T", gotLoginErr, gotLoginErr, loginErr) + } + + // Then ensure we get a finish error + gotFinishErr := loginErr.Err + + if !errors.As(gotFinishErr, &finishErr) { + t.Errorf("error %T = %v, wantErr %T", gotFinishErr, gotFinishErr, finishErr) + } + + // Then ensure we get the expected inner error + gotExpectedErr := finishErr.Err + + if !errors.As(gotExpectedErr, expectedErr) { + t.Errorf("error %T = %v, wantErr %T", gotExpectedErr, gotExpectedErr, expectedErr) } } func Test_connect_oauth_parameters(t *testing.T) { var ( - failedCallbackParameterError *internal.OAuthFailedCallbackParameterError - failedCallbackStateMatchError *internal.OAuthFailedCallbackStateMatchError + failedCallbackParameterError *internal.OAuthCallbackParameterError + failedCallbackStateMatchError *internal.OAuthCallbackStateMatchError ) tests := []struct { |
