diff options
| -rw-r--r-- | cmd/cli/main.go | 4 | ||||
| -rw-r--r-- | exports/exports.go | 1 | ||||
| -rw-r--r-- | internal/config/config.go | 25 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 61 | ||||
| -rw-r--r-- | internal/fsm/fsm.go | 21 | ||||
| -rw-r--r-- | internal/http/http.go | 55 | ||||
| -rw-r--r-- | internal/log/log.go | 18 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 144 | ||||
| -rw-r--r-- | internal/server/api.go | 78 | ||||
| -rw-r--r-- | internal/server/server.go | 152 | ||||
| -rw-r--r-- | internal/types/error.go | 62 | ||||
| -rw-r--r-- | internal/util/util.go | 7 | ||||
| -rw-r--r-- | internal/verify/verify.go | 21 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 12 | ||||
| -rw-r--r-- | state.go | 91 | ||||
| -rw-r--r-- | state_test.go | 31 |
16 files changed, 283 insertions, 500 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go index a9bff2e..e6be0bf 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -205,7 +205,9 @@ func printConfig(url string, isInstitute bool) { config, _, configErr := getConfig(state, url, isInstitute) if configErr != nil { - fmt.Println("Error getting config", configErr) + // Show the usage of tracebacks and causes + fmt.Println("Error getting config:", state.GetErrorTraceback(configErr)) + fmt.Println("Error getting config, cause:", state.GetErrorCause(configErr)) return } diff --git a/exports/exports.go b/exports/exports.go index 2139254..79b3606 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -17,6 +17,7 @@ import ( "errors" "fmt" "unsafe" + "github.com/jwijenbergh/eduvpn-common" ) diff --git a/internal/config/config.go b/internal/config/config.go index a9ebec7..a74bcd8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -5,6 +5,8 @@ import ( "fmt" "io/ioutil" "path" + + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" ) @@ -24,13 +26,14 @@ func (config *Config) GetFilename() string { } func (config *Config) Save(readStruct interface{}) error { + errorMessage := "failed saving configuration" configDirErr := util.EnsureDirectory(config.Directory) if configDirErr != nil { - return &ConfigSaveError{Err: configDirErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr} } jsonString, marshalErr := json.Marshal(readStruct) if marshalErr != nil { - return &ConfigSaveError{Err: marshalErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: marshalErr} } return ioutil.WriteFile(config.GetFilename(), jsonString, 0o600) } @@ -38,23 +41,7 @@ func (config *Config) Save(readStruct interface{}) error { func (config *Config) Load(writeStruct interface{}) error { bytes, readErr := ioutil.ReadFile(config.GetFilename()) if readErr != nil { - return &ConfigLoadError{Err: readErr} + return &types.WrappedErrorMessage{Message: "failed loading configuration", Err: readErr} } return json.Unmarshal(bytes, writeStruct) } - -type ConfigSaveError struct { - Err error -} - -func (e *ConfigSaveError) Error() string { - return fmt.Sprintf("failed to save config with error: %v", e.Err) -} - -type ConfigLoadError struct { - Err error -} - -func (e *ConfigLoadError) Error() string { - return fmt.Sprintf("failed to load config with error: %v", e.Err) -} diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index d72b4a6..ac3bf57 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -3,50 +3,15 @@ package discovery import ( "encoding/json" "fmt" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" "github.com/jwijenbergh/eduvpn-common/internal/verify" ) -type DiscoFileError struct { - URL string - Err error -} - -func (e *DiscoFileError) Error() string { - return fmt.Sprintf("failed obtaining disco file %s with error %v", e.URL, e.Err) -} - -type DiscoSigFileError struct { - URL string - Err error -} - -func (e *DiscoSigFileError) Error() string { - return fmt.Sprintf("failed obtaining disco signature file %s with error %v", e.URL, e.Err) -} - -type DiscoVerifyError struct { - File string - Sigfile string - Err error -} - -func (e *DiscoVerifyError) Error() string { - return fmt.Sprintf("failed verifying file %s with signature %s due to error %v", e.File, e.Sigfile, e.Err) -} - -type DiscoJSONError struct { - Body string - Err error -} - -func (e *DiscoJSONError) Error() string { - return fmt.Sprintf("failed parsing JSON for contents %s with error %v", e.Body, e.Err) -} - type OrganizationList struct { JSON json.RawMessage `json:"organization_list"` Version uint64 `json:"v"` @@ -68,13 +33,14 @@ type Discovery struct { // Helper function that gets a disco json func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error { + errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data discoURL := "https://disco.eduvpn.org/v2/" fileURL := discoURL + jsonFile _, fileBody, fileErr := http.HTTPGet(fileURL) if fileErr != nil { - return &DiscoFileError{fileURL, fileErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} } // Get signature @@ -83,7 +49,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, sigBody, sigFileErr := http.HTTPGet(sigURL) if sigFileErr != nil { - return &DiscoSigFileError{URL: sigURL, Err: sigFileErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} } // Verify signature @@ -92,28 +58,19 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash) if !verifySuccess || verifyErr != nil { - return &DiscoVerifyError{File: jsonFile, Sigfile: sigFile, Err: verifyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} } // Parse JSON to extract version and list jsonErr := json.Unmarshal(fileBody, structure) if jsonErr != nil { - return &DiscoJSONError{Body: string(fileBody), Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } return nil } -type GetListError struct { - File string - Err error -} - -func (e *GetListError) Error() string { - return fmt.Sprintf("failed getting disco list file %s with error %v", e.File, e.Err) -} - func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { discovery.FSM = fsm discovery.Logger = logger @@ -155,7 +112,7 @@ func (discovery *Discovery) GetOrganizationsList() (string, error) { err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) if err != nil { // Return previous with an error - return string(discovery.Organizations.JSON), &GetListError{File: file, Err: err} + return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err} } return string(discovery.Organizations.JSON), nil } @@ -169,7 +126,7 @@ func (discovery *Discovery) GetServersList() (string, error) { err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) if err != nil { // Return previous with an error - return string(discovery.Servers.JSON), &GetListError{File: file, Err: err} + return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err} } // Update servers timestamp discovery.Servers.Timestamp = util.GenerateTimeSeconds() diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go index bb7f330..3d0bfd6 100644 --- a/internal/fsm/fsm.go +++ b/internal/fsm/fsm.go @@ -1,12 +1,15 @@ package fsm import ( + "errors" "fmt" "os" "os/exec" "path" "sort" + "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) type ( @@ -209,20 +212,26 @@ func (fsm *FSM) GenerateGraph() string { return fsm.generateMermaidGraph() } -type FSMWrongStateTransitionError struct { +type DeregisteredError struct{} + +func (e DeregisteredError) CustomError() *types.WrappedErrorMessage { + return &types.WrappedErrorMessage{Message: "Client not registered with the GO library", Err: errors.New("the current FSM state is deregistered, but the function needs a state that is not deregistered")} +} + +type WrongStateTransitionError struct { Got FSMStateID Want FSMStateID } -func (e *FSMWrongStateTransitionError) Error() string { - return fmt.Sprintf("wrong FSM state, got: %s, want a state with a transition to: %s", e.Got.String(), e.Want.String()) +func (e WrongStateTransitionError) CustomError() *types.WrappedErrorMessage { + return &types.WrappedErrorMessage{Message: "Wrong FSM transition", Err: errors.New(fmt.Sprintf("wrong FSM state, got: %s, want: a state with a transition to: %s", e.Got.String(), e.Want.String()))} } -type FSMWrongStateError struct { +type WrongStateError struct { Got FSMStateID Want FSMStateID } -func (e *FSMWrongStateError) Error() string { - return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String()) +func (e WrongStateError) CustomError() *types.WrappedErrorMessage { + return &types.WrappedErrorMessage{Message: "Wrong FSM State", Err: errors.New(fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String()))} } diff --git a/internal/http/http.go b/internal/http/http.go index 87346f1..3c8e4e1 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -7,6 +7,8 @@ import ( "net/http" "net/url" "strings" + + "github.com/jwijenbergh/eduvpn-common/internal/types" ) type URLParameters map[string]string @@ -21,7 +23,7 @@ type HTTPOptionalParams struct { func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) { url, parseErr := url.Parse(baseURL) if parseErr != nil { - return "", &HTTPConstructURLError{URL: baseURL, Parameters: parameters, Err: parseErr} + return "", &types.WrappedErrorMessage{Message: fmt.Sprintf("failed to construct url: %s including parameters: %v", url, parameters), Err: parseErr} } q := url.Query() @@ -55,7 +57,7 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { url, urlErr := HTTPConstructURL(url, opts.URLParameters) if urlErr != nil { - return url, &HTTPRequestCreateError{URL: url, Err: urlErr} + return url, &types.WrappedErrorMessage{Message: fmt.Sprintf("failed to create HTTP request with url: %s", url), Err: urlErr} } return url, nil } @@ -91,10 +93,12 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht // Create a client client := &http.Client{} + errorMessage := fmt.Sprintf("failed HTTP request with method %s and url %s", method, url) + // Create request object with the body reader generated from the optional arguments req, reqErr := http.NewRequest(method, url, httpOptionalBodyReader(opts)) if reqErr != nil { - return nil, nil, &HTTPRequestCreateError{URL: url, Err: reqErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: reqErr} } // See https://stackoverflow.com/questions/17714494/golang-http-request-results-in-eof-errors-when-making-multiple-requests-successi @@ -106,7 +110,7 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht // Do request resp, respErr := client.Do(req) if respErr != nil { - return nil, nil, &HTTPResourceError{URL: url, Err: respErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: respErr} } // Request successful, make sure body is closed at the end @@ -115,26 +119,19 @@ func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) (ht // Return a string body, readErr := ioutil.ReadAll(resp.Body) if readErr != nil { - return resp.Header, nil, &HTTPReadError{URL: url, Err: readErr} + return resp.Header, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: readErr} } if resp.StatusCode < 200 || resp.StatusCode > 299 { - return resp.Header, body, &HTTPStatusError{URL: url, Status: resp.StatusCode} + // We make this a custom error because we want to extract the status code later + statusErr := &HTTPStatusError{URL: url, Status: resp.StatusCode} + return resp.Header, body, &types.WrappedErrorMessage{Message: errorMessage, Err: statusErr} } // Return the body in bytes and signal the status error if there was one return resp.Header, body, nil } -type HTTPResourceError struct { - URL string - Err error -} - -func (e *HTTPResourceError) Error() string { - return fmt.Sprintf("failed obtaining HTTP resource: %s with error: %v", e.URL, e.Err) -} - type HTTPStatusError struct { URL string Status int @@ -144,15 +141,6 @@ func (e *HTTPStatusError) Error() string { return fmt.Sprintf("failed obtaining HTTP resource: %s as it gave an unsuccesful status code: %d", e.URL, e.Status) } -type HTTPReadError struct { - URL string - Err error -} - -func (e *HTTPReadError) Error() string { - return fmt.Sprintf("failed reading HTTP resource: %s with error: %v", e.URL, e.Err) -} - type HTTPParseJsonError struct { URL string Body string @@ -162,22 +150,3 @@ type HTTPParseJsonError struct { func (e *HTTPParseJsonError) Error() string { return fmt.Sprintf("failed parsing json %s for HTTP resource: %s with error: %v", e.Body, e.URL, e.Err) } - -type HTTPRequestCreateError struct { - URL string - Err error -} - -func (e *HTTPRequestCreateError) Error() string { - return fmt.Sprintf("failed to create HTTP request with url: %s and error: %v", e.URL, e.Err) -} - -type HTTPConstructURLError struct { - URL string - Parameters URLParameters - Err error -} - -func (e *HTTPConstructURLError) Error() string { - return fmt.Sprintf("failed to construct url: %s including parameters: %v with error: %v", e.URL, e.Parameters, e.Err) -} diff --git a/internal/log/log.go b/internal/log/log.go index cba3364..f4024e2 100644 --- a/internal/log/log.go +++ b/internal/log/log.go @@ -5,6 +5,8 @@ import ( "log" "os" "path" + + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" ) @@ -38,13 +40,15 @@ func (e LogLevel) String() string { } func (logger *FileLogger) Init(level LogLevel, name string, directory string) error { + errorMessage := "failed creating log" + configDirErr := util.EnsureDirectory(directory) if configDirErr != nil { - return &LogInitializeError{Name: name, Directory: directory, Err: configDirErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: configDirErr} } logFile, logOpenErr := os.OpenFile(logger.getFilename(directory, name), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0o666) if logOpenErr != nil { - return &LogInitializeError{Name: name, Directory: directory, Err: logOpenErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: logOpenErr} } log.SetOutput(logFile) logger.File = logFile @@ -66,13 +70,3 @@ func (logger *FileLogger) Log(level LogLevel, str string) { func (logger *FileLogger) Close() { logger.File.Close() } - -type LogInitializeError struct { - Name string - Directory string - Err error -} - -func (e *LogInitializeError) Error() string { - return fmt.Sprintf("failed initializing logging with name: %s and directory: %s with error: %v", e.Name, e.Directory, e.Err) -} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index f6ed916..824db90 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -8,10 +8,12 @@ import ( "fmt" "net/http" "net/url" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" httpw "github.com/jwijenbergh/eduvpn-common/internal/http" - "github.com/jwijenbergh/eduvpn-common/internal/util" "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/types" + "github.com/jwijenbergh/eduvpn-common/internal/util" ) // Generates a random base64 string to be used for state @@ -23,7 +25,7 @@ import ( func genState() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenStateError{Err: err} + return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err} } // For consistency we also use raw url encoding here @@ -49,7 +51,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenVerifierError{Err: err} + return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth verifier", Err: err} } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -60,8 +62,8 @@ type OAuth struct { Token OAuthToken `json:"token"` BaseAuthorizationURL string `json:"base_authorization_url"` TokenURL string `json:"token_url"` - Logger *log.FileLogger `json:"-"` - FSM *fsm.FSM `json:"-"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` } // This structure gets passed to the callback for easy access to the current state @@ -99,7 +101,7 @@ func (oauth *OAuth) getTokensWithCallback() error { } mux.HandleFunc("/callback", oauth.Callback) if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { - return &OAuthCallbackError{Addr: addr, Err: err} + return &types.WrappedErrorMessage{Message: "failed getting tokens with callback", Err: err} } return oauth.Session.CallbackError } @@ -108,9 +110,9 @@ func (oauth *OAuth) getTokensWithCallback() error { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { + errorMessage := "failed getting tokens with the authorization code" // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code - reqURL := oauth.TokenURL data := url.Values{ "client_id": {oauth.Session.ClientID}, @@ -126,7 +128,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { current_time := util.GenerateTimeSeconds() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &OAuthAuthError{Err: bodyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } tokenStructure := OAuthToken{} @@ -134,7 +136,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}} } tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires @@ -152,6 +154,7 @@ func (oauth *OAuth) isTokensExpired() bool { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 func (oauth *OAuth) getTokensWithRefresh() error { + errorMessage := "failed getting tokens with the refresh token" reqURL := oauth.TokenURL data := url.Values{ "refresh_token": {oauth.Token.Refresh}, @@ -164,14 +167,14 @@ func (oauth *OAuth) getTokensWithRefresh() error { current_time := util.GenerateTimeSeconds() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &OAuthRefreshError{Err: bodyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } tokenStructure := OAuthToken{} jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}} } tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires @@ -182,11 +185,15 @@ func (oauth *OAuth) getTokensWithRefresh() error { // //// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { + errorMessage := "failed callback to retrieve the authorization code" // Extract the authorization code code, success := req.URL.Query()["code"] - if !success { - oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()} + // Shutdown after we're done + defer func() { go oauth.Session.Server.Shutdown(oauth.Session.Context) + }() + if !success { + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}} return } // The code is the first entry @@ -195,30 +202,25 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Make sure the state is present and matches to protect against cross-site request forgeries // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state, success := req.URL.Query()["state"] + if !success { - oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()} - go oauth.Session.Server.Shutdown(oauth.Session.Context) + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}} return } // The state is the first entry extractedState := state[0] if extractedState != oauth.Session.State { - oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} - go oauth.Session.Server.Shutdown(oauth.Session.Context) + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}} return } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - err := oauth.getTokensWithAuthCode(extractedCode) - if err != nil { - oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err} - go oauth.Session.Server.Shutdown(oauth.Session.Context) + getTokensErr := oauth.getTokensWithAuthCode(extractedCode) + if getTokensErr != nil { + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: getTokensErr} return } - - // Shutdown the server as we're done listening - go oauth.Session.Server.Shutdown(oauth.Session.Context) } func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) { @@ -235,19 +237,20 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm. // Starts the OAuth exchange for eduvpn. func (oauth *OAuth) start(name string) error { + errorMessage := "failed starting OAuth exchange" if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { - return &fsm.FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} } // Generate the state state, stateErr := genState() if stateErr != nil { - return &OAuthInitializeError{Err: stateErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} } // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return &OAuthInitializeError{Err: verifierErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr} } challenge := genChallengeS256(verifier) @@ -264,7 +267,7 @@ func (oauth *OAuth) start(name string) error { authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { - return &OAuthInitializeError{Err: urlErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client @@ -277,34 +280,36 @@ func (oauth *OAuth) start(name string) error { // Error definitions func (oauth *OAuth) Finish() error { + errorMessage := "failed finishing OAuth" if !oauth.FSM.HasTransition(fsm.AUTHORIZED) { - return &fsm.FSMWrongStateError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED}.CustomError()} } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return &OAuthFinishError{Err: tokenErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr} } oauth.FSM.GoTransition(fsm.AUTHORIZED) return nil } func (oauth *OAuth) Cancel() { - oauth.Session.CallbackError = &OAuthCancelledCallbackError{} + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: "failed cancelling OAuth", Err: &OAuthCancelledCallbackError{}} oauth.Session.Server.Shutdown(oauth.Session.Context) } func (oauth *OAuth) Login(name string) error { + errorMessage := "failed OAuth login" authInitializeErr := oauth.start(name) if authInitializeErr != nil { - return &OAuthLoginError{Err: authInitializeErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} } oauthErr := oauth.Finish() if oauthErr != nil { - return &OAuthLoginError{Err: oauthErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } return nil } @@ -342,31 +347,6 @@ func (e *OAuthCancelledCallbackError) Error() string { return fmt.Sprintf("client cancelled OAuth") } -type OAuthGenStateError struct { - Err error -} - -func (e *OAuthGenStateError) Error() string { - return fmt.Sprintf("failed generating state with error: %v", e.Err) -} - -type OAuthGenVerifierError struct { - Err error -} - -func (e *OAuthGenVerifierError) Error() string { - return fmt.Sprintf("failed generating verifier with error: %v", e.Err) -} - -type OAuthCallbackError struct { - Addr string - Err error -} - -func (e *OAuthCallbackError) Error() string { - return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err) -} - type OAuthCallbackParameterError struct { Parameter string URL string @@ -384,51 +364,3 @@ type OAuthCallbackStateMatchError struct { func (e *OAuthCallbackStateMatchError) Error() string { return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } - -type OAuthCallbackGetTokensError struct { - Err error -} - -func (e *OAuthCallbackGetTokensError) Error() string { - return fmt.Sprintf("failed getting tokens with error: %v", e.Err) -} - -type OAuthFinishError struct { - Err error -} - -func (e *OAuthFinishError) Error() string { - return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err) -} - -type OAuthLoginError struct { - Err error -} - -func (e *OAuthLoginError) Error() string { - return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err) -} - -type OAuthInitializeError struct { - Err error -} - -func (e *OAuthInitializeError) Error() string { - return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err) -} - -type OAuthAuthError struct { - Err error -} - -func (e *OAuthAuthError) Error() string { - return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err) -} - -type OAuthRefreshError struct { - Err error -} - -func (e *OAuthRefreshError) Error() string { - return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err) -} diff --git a/internal/server/api.go b/internal/server/api.go index 96bd641..c8c7180 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -6,32 +6,34 @@ import ( "fmt" "net/http" "net/url" + httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" ) func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { + errorMessage := "failed getting server endpoints" url := fmt.Sprintf("%s/%s", baseURL, WellKnownPath) _, body, bodyErr := httpw.HTTPGet(url) if bodyErr != nil { - return nil, &APIGetEndpointsError{Err: bodyErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } endpoints := &ServerEndpoints{} jsonErr := json.Unmarshal(body, endpoints) if jsonErr != nil { - return nil, &APIGetEndpointsError{Err: jsonErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } return endpoints, nil } -// Authorized wrappers on top of HTTP -// the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) { + errorMessage := "failed API authorized" // Ensure optional is not nil as we will fill it with headers if opts == nil { opts = &httpw.HTTPOptionalParams{} @@ -39,7 +41,7 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT base, baseErr := server.GetBase() if baseErr != nil { - return nil, nil, baseErr + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } url := base.Endpoints.API.V3.API + endpoint @@ -52,7 +54,7 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT base.FSM.Current = stateBefore if oauthErr != nil { - return nil, nil, oauthErr + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } headerKey := "Authorization" @@ -66,11 +68,12 @@ func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HT } func apiAuthorizedRetry(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) { + errorMessage := "failed authorized API retry" header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) base, baseErr := server.GetBase() if baseErr != nil { - return nil, nil, &APIAuthorizedError{Err: baseErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if bodyErr != nil { var error *httpw.HTTPStatusError @@ -82,31 +85,32 @@ func apiAuthorizedRetry(server Server, method string, endpoint string, opts *htt server.GetOAuth().Token.ExpiredTimestamp = util.GenerateTimeSeconds() retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { - return nil, nil, &APIAuthorizedError{Err: retryErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr} } return retryHeader, retryBody, nil } - return nil, nil, &APIAuthorizedError{Err: bodyErr} + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } return header, body, nil } func APIInfo(server Server) error { + errorMessage := "failed API /info" _, body, bodyErr := apiAuthorizedRetry(server, http.MethodGet, "/info", nil) if bodyErr != nil { - return &APIInfoError{Err: bodyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } structure := ServerProfileInfo{} jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { - return &APIInfoError{Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } base, baseErr := server.GetBase() if baseErr != nil { - return &APIInfoError{Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } // Store the profiles and make sure that the current profile is not overwritten @@ -118,6 +122,7 @@ func APIInfo(server Server) error { } func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, int64, error) { + errorMessage := "failed obtaining a WireGuard configuration" headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, @@ -133,14 +138,14 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor } header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", "", 0, &APIConnectWireguardError{Err: connectErr} + return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr} } expires := header.Get("expires") pTime, pTimeErr := http.ParseTime(expires) if pTimeErr != nil { - return "", "", 0, &APIConnectWireguardError{Err: pTimeErr} + return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr} } contentType := header.Get("content-type") @@ -153,6 +158,7 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor } func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) { + errorMessage := "failed obtaining an OpenVPN configuration" headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, @@ -164,13 +170,13 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { - return "", 0, &APIConnectOpenVPNError{Err: connectErr} + return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr} } expires := header.Get("expires") pTime, pTimeErr := http.ParseTime(expires) if pTimeErr != nil { - return "", 0, &APIConnectOpenVPNError{Err: pTimeErr} + return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr} } return string(connectBody), pTime.Unix(), nil } @@ -179,43 +185,3 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) func APIDisconnect(server Server) { apiAuthorizedRetry(server, http.MethodPost, "/disconnect", nil) } - -type APIAuthorizedError struct { - Err error -} - -func (e *APIAuthorizedError) Error() string { - return fmt.Sprintf("failed api authorized call with error: %v", e.Err) -} - -type APIConnectWireguardError struct { - Err error -} - -func (e *APIConnectWireguardError) Error() string { - return fmt.Sprintf("failed api /connect wireguard call with error: %v", e.Err) -} - -type APIConnectOpenVPNError struct { - Err error -} - -func (e *APIConnectOpenVPNError) Error() string { - return fmt.Sprintf("failed api /connect OpenVPN call with error: %v", e.Err) -} - -type APIInfoError struct { - Err error -} - -func (e *APIInfoError) Error() string { - return fmt.Sprintf("failed api /info call with error: %v", e.Err) -} - -type APIGetEndpointsError struct { - Err error -} - -func (e *APIGetEndpointsError) Error() string { - return fmt.Sprintf("failed to get server endpoint with error %v", e.Err) -} diff --git a/internal/server/server.go b/internal/server/server.go index a1fb749..ce72400 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,9 +2,11 @@ package server import ( "fmt" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/log" "github.com/jwijenbergh/eduvpn-common/internal/oauth" + "github.com/jwijenbergh/eduvpn-common/internal/types" "github.com/jwijenbergh/eduvpn-common/internal/util" "github.com/jwijenbergh/eduvpn-common/internal/wireguard" ) @@ -17,8 +19,8 @@ type ServerBase struct { ProfilesRaw string `json:"profiles_raw"` StartTime int64 `json:"start-time"` EndTime int64 `json:"end-time"` - Logger *log.FileLogger `json:"-"` - FSM *fsm.FSM `json:"-"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` } // An instute access server @@ -49,18 +51,19 @@ type InstituteServers struct { } func (servers *Servers) GetCurrentServer() (Server, error) { + errorMessage := "failed getting current server" if servers.IsSecureInternet { return &servers.SecureInternetHomeServer, nil } currentInstitute := servers.InstituteServers.CurrentURL institutes := servers.InstituteServers.Map if institutes == nil { - return nil, &ServerGetCurrentNoMapError{} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentNoMapError{}} } institute, exists := institutes[currentInstitute] if !exists || institute == nil { - return nil, &ServerGetCurrentNotFoundError{} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentNotFoundError{}} } return institute, nil } @@ -96,25 +99,27 @@ func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { } func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { + errorMessage := "failed getting current secure internet home base" if server.BaseMap == nil { - return nil, &ServerSecureInternetMapNotFoundError{} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}} } base, exists := server.BaseMap[server.CurrentURL] if !exists { - return nil, &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}} } return base, nil } func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := fmt.Sprintf("failed initializing institute server %s", url) institute.Base.URL = url institute.Base.FSM = fsm institute.Base.Logger = logger endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { - return &ServerInitializeError{URL: url, Err: endpointsErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } institute.OAuth.Init(endpoints.API.V3.Authorization, endpoints.API.V3.Token, fsm, logger) institute.Base.Endpoints = *endpoints @@ -122,6 +127,7 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l } func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url) // Initialize the base map if it is non-nil if secure.BaseMap == nil { secure.BaseMap = make(map[string]*ServerBase) @@ -136,7 +142,7 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l base.URL = url endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { - return &ServerInitializeError{URL: url, Err: endpointsErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } base.Endpoints = *endpoints } @@ -166,7 +172,7 @@ func ShouldRenewButton(server Server) (bool, error) { base, baseErr := server.GetBase() if baseErr != nil { - //return false, &GetRenewButtonTimeError{Err: baseErr} + // return false, &GetRenewButtonTimeError{Err: baseErr} return false, nil } @@ -186,7 +192,7 @@ func ShouldRenewButton(server Server) (bool, error) { // Session duration is less than 24 hours but not 75% has passed duration := base.EndTime - base.StartTime // TODO: Is converting to float64 okay here? - if duration < 24*60*60 && float64(current) <= (float64(base.StartTime) + 0.75*float64(duration)) { + if duration < 24*60*60 && float64(current) <= (float64(base.StartTime)+0.75*float64(duration)) { return false, nil } @@ -198,17 +204,18 @@ func Login(server Server) error { } func EnsureTokens(server Server) error { + errorMessage := "failed ensuring server tokens" base, baseErr := server.GetBase() if baseErr != nil { - return &ServerEnsureTokensError{Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if server.GetOAuth().NeedsRelogin() { base.Logger.Log(log.LOG_INFO, "OAuth: Tokens are invalid, relogging in") loginErr := Login(server) if loginErr != nil { - return &ServerEnsureTokensError{Err: loginErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} } } return nil @@ -223,13 +230,14 @@ func CancelOAuth(server Server) { } func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { + errorMessage := "failed ensuring server" // Intialize the secure internet server // This calls the init method which takes care of the rest if isSecureInternet { initErr := servers.SecureInternetHomeServer.init(url, fsm, logger) if initErr != nil { - return nil, &ServerEnsureServerError{Err: initErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} } servers.IsSecureInternet = true @@ -253,7 +261,7 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm instituteServers.CurrentURL = url instituteInitErr := institute.init(url, fsm, logger) if instituteInitErr != nil { - return nil, &ServerEnsureServerError{Err: instituteInitErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: instituteInitErr} } instituteServers.Map[url] = institute servers.IsSecureInternet = false @@ -310,10 +318,11 @@ func (profile *ServerProfile) supportsOpenVPN() bool { } func getCurrentProfile(server Server) (*ServerProfile, error) { + errorMessage := "failed getting current profile" base, baseErr := server.GetBase() if baseErr != nil { - return nil, &ServerGetCurrentProfileError{Err: baseErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profileID := base.Profiles.Current for _, profile := range base.Profiles.Info.ProfileList { @@ -321,28 +330,30 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { return &profile, nil } } - return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} + + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}} } func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, error) { + errorMessage := "failed getting server WireGuard configuration" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", baseErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profile_id := base.Profiles.Current wireguardKey, wireguardErr := wireguard.GenerateKey() if wireguardErr != nil { - return "", "", wireguardErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: wireguardErr} } wireguardPublicKey := wireguardKey.PublicKey().String() config, content, expires, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN) if configErr != nil { - return "", "", wireguardErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } // Store start and end time @@ -361,10 +372,11 @@ func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, er } func openVPNGetConfig(server Server) (string, string, error) { + errorMessage := "failed getting server OpenVPN configuration" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", baseErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } profile_id := base.Profiles.Current configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id) @@ -374,25 +386,26 @@ func openVPNGetConfig(server Server) (string, string, error) { base.EndTime = expires if configErr != nil { - return "", "", configErr + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return configOpenVPN, "openvpn", nil } func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration with a profile" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &ServerGetConfigWithProfileError{Err: baseErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if !base.FSM.HasTransition(fsm.HAS_CONFIG) { - return "", "", &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError()} } profile, profileErr := getCurrentProfile(server) if profileErr != nil { - return "", "", &ServerGetConfigWithProfileError{Err: profileErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } supportsOpenVPN := profile.supportsOpenVPN() @@ -400,7 +413,7 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) // If forceTCP we must be able to get a config with OpenVPN if forceTCP && supportsOpenVPN { - return "", "", &ServerGetConfigForceTCPError{} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerGetConfigForceTCPError{}} } var config string @@ -416,40 +429,42 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) } if configErr != nil { - return "", "", &ServerGetConfigWithProfileError{Err: configErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } return config, configType, nil } func askForProfileID(server Server) error { + errorMessage := "failed asking for a server profile ID" base, baseErr := server.GetBase() if baseErr != nil { - return &ServerAskForProfileIDError{Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if !base.FSM.HasTransition(fsm.ASK_PROFILE) { - return &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE}.CustomError()} } base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, base.ProfilesRaw, false) return nil } func GetConfig(server Server, forceTCP bool) (string, string, error) { + errorMessage := "failed getting an OpenVPN/WireGuard configuration" base, baseErr := server.GetBase() if baseErr != nil { - return "", "", &ServerGetConfigError{Err: baseErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } if !base.FSM.InState(fsm.REQUEST_CONFIG) { - return "", "", &fsm.FSMWrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG}.CustomError()} } // Get new profiles using the info call // This does not override the current profile infoErr := APIInfo(server) if infoErr != nil { - return "", "", &ServerGetConfigError{Err: infoErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: infoErr} } // If there was a profile chosen and it doesn't exist anymore, reset it @@ -473,7 +488,7 @@ func GetConfig(server Server, forceTCP bool) (string, string, error) { profileErr := askForProfileID(server) if profileErr != nil { - return "", "", &ServerGetConfigError{Err: profileErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: profileErr} } return getConfigWithProfile(server, forceTCP) @@ -487,14 +502,6 @@ func (e *ServerGetCurrentProfileNotFoundError) Error() string { return fmt.Sprintf("failed to get current profile, profile with ID: %s not found", e.ProfileID) } -type ServerGetConfigWithProfileError struct { - Err error -} - -func (e *ServerGetConfigWithProfileError) Error() string { - return fmt.Sprintf("failed to get config including profile with error %v", e.Err) -} - type ServerGetConfigForceTCPError struct{} func (e *ServerGetConfigForceTCPError) Error() string { @@ -507,28 +514,12 @@ func (e *ServerGetSecureInternetHomeError) Error() string { return "failed to get secure internet home server, not found" } -type ServerCopySecureInternetOAuthError struct { - Err error -} - -func (e *ServerCopySecureInternetOAuthError) Error() string { - return fmt.Sprintf("failed to copy oauth tokens from home server with error %v", e.Err) -} - type ServerEnsureServerEmptyURLError struct{} func (e *ServerEnsureServerEmptyURLError) Error() string { return "failed ensuring server, empty url provided" } -type ServerEnsureServerError struct { - Err error -} - -func (e *ServerEnsureServerError) Error() string { - return fmt.Sprintf("failed ensuring server with error %v", e.Err) -} - type ServerGetCurrentNoMapError struct{} func (e *ServerGetCurrentNoMapError) Error() string { @@ -541,31 +532,6 @@ func (e *ServerGetCurrentNotFoundError) Error() string { return "failed getting current server, not found" } -type ServerGetConfigError struct { - Err error -} - -func (e *ServerGetConfigError) Error() string { - return fmt.Sprintf("failed getting server config with error %v", e.Err) -} - -type ServerInitializeError struct { - URL string - Err error -} - -func (e *ServerInitializeError) Error() string { - return fmt.Sprintf("failed initializing server with url %s and error %v", e.URL, e.Err) -} - -type ServerInstituteBaseNotFoundError struct { - Err error -} - -func (e *ServerInstituteBaseNotFoundError) Error() string { - return "institute base not found" -} - type ServerSecureInternetMapNotFoundError struct{} func (e *ServerSecureInternetMapNotFoundError) Error() string { @@ -579,27 +545,3 @@ type ServerSecureInternetBaseNotFoundError struct { func (e *ServerSecureInternetBaseNotFoundError) Error() string { return fmt.Sprintf("secure internet base not found with current: %s", e.Current) } - -type ServerGetCurrentProfileError struct { - Err error -} - -func (e *ServerGetCurrentProfileError) Error() string { - return fmt.Sprintf("failed getting current profile with error: %v", e.Err) -} - -type ServerAskForProfileIDError struct { - Err error -} - -func (e *ServerAskForProfileIDError) Error() string { - return fmt.Sprintf("ask for profile ID error: %v", e.Err) -} - -type ServerEnsureTokensError struct { - Err error -} - -func (e *ServerEnsureTokensError) Error() string { - return fmt.Sprintf("failed ensuring tokens with error: %v", e.Err) -} diff --git a/internal/types/error.go b/internal/types/error.go new file mode 100644 index 0000000..fda7c9c --- /dev/null +++ b/internal/types/error.go @@ -0,0 +1,62 @@ +package types + +import ( + "errors" + "fmt" +) + +type WrappedErrorMessage struct { + Message string + Err error +} + +func (e *WrappedErrorMessage) Unwrap() error { + return e.Err +} + +func (e *WrappedErrorMessage) Cause() error { + causeErr := e.Err + for errors.Unwrap(causeErr) != nil { + causeErr = errors.Unwrap(causeErr) + } + return causeErr +} + +func (e *WrappedErrorMessage) Traceback() string { + returnStr := fmt.Sprintf("Traceback for error: %s", e.Message) + causeErr := e.Err + for errors.Unwrap(causeErr) != nil { + causeErr = errors.Unwrap(causeErr) + var wrappedErr *WrappedErrorMessage + + errorStr := causeErr.Error() + + if errors.As(causeErr, &wrappedErr) { + errorStr = wrappedErr.Message + } + returnStr += fmt.Sprintf("\n - %s", errorStr) + } + return returnStr +} + +func (e *WrappedErrorMessage) Error() string { + return fmt.Sprintf("Got error: %s, with cause: %s", e.Message, e.Err) +} + +func GetErrorTraceback(err error) string { + var wrappedErr *WrappedErrorMessage + + if errors.As(err, &wrappedErr) { + return wrappedErr.Traceback() + } + return err.Error() +} + +func GetErrorCause(err error) error { + var wrappedErr *WrappedErrorMessage + + if errors.As(err, &wrappedErr) { + return wrappedErr.Cause() + } + return err +} diff --git a/internal/util/util.go b/internal/util/util.go index 4bdd1b5..8dee61e 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -2,8 +2,11 @@ package util import ( "crypto/rand" + "fmt" "os" "time" + + "github.com/jwijenbergh/eduvpn-common/internal/types" ) // Creates a random byteslice of `size` @@ -11,7 +14,7 @@ func MakeRandomByteSlice(size int) ([]byte, error) { byteSlice := make([]byte, size) _, err := rand.Read(byteSlice) if err != nil { - return nil, err + return nil, &types.WrappedErrorMessage{Message: "failed reading random", Err: err} } return byteSlice, nil } @@ -24,7 +27,7 @@ func GenerateTimeSeconds() int64 { func EnsureDirectory(directory string) error { mkdirErr := os.MkdirAll(directory, os.ModePerm) if mkdirErr != nil { - return mkdirErr + return &types.WrappedErrorMessage{Message: fmt.Sprintf("failed to create directory %s", directory), Err: mkdirErr} } return nil } diff --git a/internal/verify/verify.go b/internal/verify/verify.go index 2d53b2e..b159297 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -6,6 +6,7 @@ import ( "os" "github.com/jedisct1/go-minisign" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) // getKeys returns keys taken from https://git.sr.ht/~eduvpn/disco.eduvpn.org#public-keys. @@ -28,16 +29,19 @@ func getKeys() []string { // // Verify is a wrapper around verifyWithKeys where allowedPublicKeys is set to the list from https://git.sr.ht/~eduvpn/disco.eduvpn.org#public-keys. func Verify(signatureFileContent string, signedJson []byte, expectedFileName string, minSignTime uint64, forcePrehash bool) (bool, error) { + errorMessage := "failed signature verify" keyStrs := getKeys() if extraKey != "" { keyStrs = append(keyStrs, extraKey) _, err := fmt.Fprintf(os.Stderr, "INSECURE TEST MODE ENABLED WITH KEY %q\n", extraKey) + err = &types.WrappedErrorMessage{Message: errorMessage, Err: err} if err != nil { panic(err) } } valid, err := verifyWithKeys(signatureFileContent, signedJson, expectedFileName, minSignTime, keyStrs, forcePrehash) if err != nil { + err = &types.WrappedErrorMessage{Message: errorMessage, Err: err} var verifyCreatePublickeyError *VerifyCreatePublicKeyError if errors.As(err, &verifyCreatePublickeyError) { panic(err) // This should not happen unless keyStrs has an invalid key @@ -67,6 +71,7 @@ func InsecureTestingSetExtraKey(keyString string) { // The signature is checked to have a timestamp with a value of at least minSignTime, which is a UNIX timestamp without milliseconds. // // The return value will either be (true, nil) on success or (false, detailedVerifyError) on failure. +// Note that every error path is wrapped in a custom type here because minisign does not return custom error types, they use errors.New func verifyWithKeys(signatureFileContent string, signedJson []byte, filename string, minSignTime uint64, allowedPublicKeys []string, forcePrehash bool) (bool, error) { switch filename { case "server_list.json", "organization_list.json": @@ -143,6 +148,10 @@ func (e *VerifyInvalidSignatureFormatError) Error() string { return fmt.Sprintf("invalid signature format with error: %v", e.Err) } +func (e *VerifyInvalidSignatureFormatError) Unwrap() error { + return e.Err +} + type VerifyInvalidSignatureAlgorithmError struct { Algorithm string WantedAlgorithm string @@ -161,6 +170,10 @@ func (e *VerifyCreatePublicKeyError) Error() string { return fmt.Sprintf("failed to create public key: %s with error: %v", e.PublicKey, e.Err) } +func (e *VerifyCreatePublicKeyError) Unwrap() error { + return e.Err +} + type VerifyInvalidSignatureError struct { Err error } @@ -169,6 +182,10 @@ func (e *VerifyInvalidSignatureError) Error() string { return fmt.Sprintf("invalid signature with error: %v", e.Err) } +func (e *VerifyInvalidSignatureError) Unwrap() error { + return e.Err +} + type VerifyInvalidTrustedCommentError struct { TrustedComment string Err error @@ -178,6 +195,10 @@ func (e *VerifyInvalidTrustedCommentError) Error() string { return fmt.Sprintf("invalid trusted comment: %s with error: %v", e.TrustedComment, e.Err) } +func (e *VerifyInvalidTrustedCommentError) Unwrap() error { + return e.Err +} + type VerifyWrongSigFilenameError struct { Filename string SigFilename string diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index db20067..bb26b69 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -3,6 +3,8 @@ package wireguard import ( "fmt" "regexp" + + "github.com/jwijenbergh/eduvpn-common/internal/types" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -10,7 +12,7 @@ func GenerateKey() (wgtypes.Key, error) { key, keyErr := wgtypes.GeneratePrivateKey() if keyErr != nil { - return key, &WireguardGenerateKeyError{Err: keyErr} + return key, &types.WrappedErrorMessage{Message: "failed generating WireGuard key", Err: keyErr} } return key, nil } @@ -28,11 +30,3 @@ func ConfigAddKey(config string, key wgtypes.Key) string { to_replace := fmt.Sprintf("%s\nPrivateKey = %s", interface_section, key.String()) return interface_re.ReplaceAllString(config, to_replace) } - -type WireguardGenerateKeyError struct { - Err error -} - -func (e *WireguardGenerateKeyError) Error() string { - return fmt.Sprintf("failed generating Wireguard key with error: %v", e.Err) -} @@ -1,13 +1,12 @@ package eduvpn import ( - "fmt" - "github.com/jwijenbergh/eduvpn-common/internal/config" "github.com/jwijenbergh/eduvpn-common/internal/discovery" "github.com/jwijenbergh/eduvpn-common/internal/fsm" "github.com/jwijenbergh/eduvpn-common/internal/log" "github.com/jwijenbergh/eduvpn-common/internal/server" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) type VPNState struct { @@ -31,8 +30,9 @@ type VPNState struct { } func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { + errorMessage := "failed to register with the GO library" if !state.FSM.InState(fsm.DEREGISTERED) { - return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.DEREGISTERED} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()} } // Initialize the logger logLevel := log.LOG_WARNING @@ -43,7 +43,7 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun loggerErr := state.Logger.Init(logLevel, name, directory) if loggerErr != nil { - return &StateRegisterError{Err: loggerErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: loggerErr} } // Initialize the FSM @@ -78,14 +78,15 @@ func (state *VPNState) Deregister() error { } func (state *VPNState) CancelOAuth() error { + errorMessage := "failed to cancel OAuth" if !state.FSM.InState(fsm.OAUTH_STARTED) { - return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} } currentServer, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - return &StateOAuthCancelError{Err: serverErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } server.CancelOAuth(currentServer) return nil @@ -96,7 +97,7 @@ func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.S server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger) if serverErr != nil { - return nil, serverErr + return nil, &types.WrappedErrorMessage{Message: "failed to choose server", Err: serverErr} } // Make sure we are in the chosen state if available @@ -105,20 +106,21 @@ func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.S } func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, forceTCP bool) (string, string, error) { + errorMessage := "failed to get a configuration for OpenVPN/Wireguard" if state.FSM.InState(fsm.DEREGISTERED) { - return "", "", &StateFSMNotRegisteredError{} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()} } // Go to no server if possible, else return an error if !state.FSM.InState(fsm.NO_SERVER) && !state.FSM.GoTransition(fsm.NO_SERVER) { - return "", "", &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER}.CustomError()} } // Make sure the server is chosen chosenServer, serverErr := state.chooseServer(url, isSecureInternet) if serverErr != nil { - return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } // Relogin with oauth // This moves the state to authorized @@ -129,7 +131,7 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f // We are possibly in oauth started // So go to no server state.FSM.GoTransition(fsm.NO_SERVER) - return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: loginErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: loginErr} } } else { // OAuth was valid, ensure we are in the authorized state state.FSM.GoTransition(fsm.AUTHORIZED) @@ -142,7 +144,7 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f if configErr != nil { // Go back to no server if possible state.FSM.GoTransition(fsm.NO_SERVER) - return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr} + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: configErr} } else { state.FSM.GoTransition(fsm.HAS_CONFIG) } @@ -160,32 +162,33 @@ func (state *VPNState) GetConfigSecureInternet(url string, forceTCP bool) (strin func (state *VPNState) GetDiscoOrganizations() (string, error) { if state.FSM.InState(fsm.DEREGISTERED) { - return "", &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.DEREGISTERED} + return "", &types.WrappedErrorMessage{Message: "failed to get the organizations with Discovery", Err: fsm.DeregisteredError{}.CustomError()} } return state.Discovery.GetOrganizationsList() } func (state *VPNState) GetDiscoServers() (string, error) { if state.FSM.InState(fsm.DEREGISTERED) { - return "", &StateFSMNotRegisteredError{} + return "", &types.WrappedErrorMessage{Message: "failed to get the servers with Discovery", Err: fsm.DeregisteredError{}.CustomError()} } return state.Discovery.GetServersList() } func (state *VPNState) SetProfileID(profileID string) error { + errorMessage := "failed to set the profile ID for the current server" if !state.FSM.InState(fsm.ASK_PROFILE) { - return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.ASK_PROFILE} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.ASK_PROFILE}.CustomError()} } server, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { - return &StateSetProfileError{ProfileID: profileID, Err: serverErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} } base, baseErr := server.GetBase() if baseErr != nil { - return &StateSetProfileError{ProfileID: profileID, Err: baseErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } base.Profiles.Current = profileID return nil @@ -193,7 +196,7 @@ func (state *VPNState) SetProfileID(profileID string) error { func (state *VPNState) SetConnected() error { if !state.FSM.HasTransition(fsm.CONNECTED) { - return &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.CONNECTED} + return fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.CONNECTED}.CustomError() } state.FSM.GoTransition(fsm.CONNECTED) @@ -202,59 +205,17 @@ func (state *VPNState) SetConnected() error { func (state *VPNState) SetDisconnected() error { if !state.FSM.HasTransition(fsm.HAS_CONFIG) { - return &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG} + return fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError() } state.FSM.GoTransition(fsm.HAS_CONFIG) return nil } -type StateSetProfileError struct { - ProfileID string - Err error -} - -func (e *StateSetProfileError) Error() string { - return fmt.Sprintf("failed to set profile ID: %s with error: %v", e.ProfileID, e.Err) -} - -type StateRegisterError struct { - Err error -} - -func (e *StateRegisterError) Error() string { - return fmt.Sprintf("failed to register with error: %v", e.Err) -} - -type StateFSMNotRegisteredError struct{} - -func (e *StateFSMNotRegisteredError) Error() string { - return fmt.Sprintf("state is not registered. Current FSM state: %s", fsm.DEREGISTERED.String()) -} - -type StateWrongFSMStateError struct { - Got fsm.FSMStateID - Want fsm.FSMStateID -} - -func (e *StateWrongFSMStateError) Error() string { - return fmt.Sprintf("wrong FSM state, got: %s, want: %s", e.Got.String(), e.Want.String()) -} - -type StateOAuthCancelError struct { - Err error -} - -func (e *StateOAuthCancelError) Error() string { - return fmt.Sprintf("failed cancelling OAuth for state with error: %v", e.Err) -} - -type StateConnectError struct { - URL string - IsSecureInternet bool - Err error +func (state *VPNState) GetErrorTraceback(err error) string { + return types.GetErrorTraceback(err) } -func (e *StateConnectError) Error() string { - return fmt.Sprintf("failed connecting to server: %s (is secure internet: %v) with error: %v", e.URL, e.IsSecureInternet, e.Err) +func (state *VPNState) GetErrorCause(err error) error { + return types.GetErrorCause(err) } diff --git a/state_test.go b/state_test.go index 521513a..76f4763 100644 --- a/state_test.go +++ b/state_test.go @@ -15,6 +15,7 @@ import ( httpw "github.com/jwijenbergh/eduvpn-common/internal/http" "github.com/jwijenbergh/eduvpn-common/internal/oauth" "github.com/jwijenbergh/eduvpn-common/internal/server" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) func ensureLocalWellKnown() { @@ -96,34 +97,16 @@ func test_connect_oauth_parameter(t *testing.T, parameters httpw.URLParameters, }, false) _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) - var stateErr *StateConnectError - var loginErr *oauth.OAuthLoginError - var finishErr *oauth.OAuthFinishError + var wrappedErr *types.WrappedErrorMessage - // We go through the chain of errors by unwrapping them one by one - - // First ensure we get a state connect error - if !errors.As(configErr, &stateErr) { - t.Fatalf("error %T = %v, wantErr %T", configErr, configErr, stateErr) - } - - // Then ensure we get a login error - gotLoginErr := stateErr.Err - - if !errors.As(gotLoginErr, &loginErr) { - t.Fatalf("error %T = %v, wantErr %T", gotLoginErr, gotLoginErr, loginErr) - } - - // Then ensure we get a finish error - gotFinishErr := loginErr.Err - - if !errors.As(gotFinishErr, &finishErr) { - t.Fatalf("error %T = %v, wantErr %T", gotFinishErr, gotFinishErr, finishErr) + // We ensure the error is of a wrappedErrorMessage + if !errors.As(configErr, &wrappedErr) { + t.Fatalf("error %T = %v, wantErr %T", configErr, configErr, wrappedErr) } - // Then ensure we get the expected inner error - gotExpectedErr := finishErr.Err + gotExpectedErr := wrappedErr.Cause() + // Then we check if the cause is correct if !errors.As(gotExpectedErr, expectedErr) { t.Fatalf("error %T = %v, wantErr %T", gotExpectedErr, gotExpectedErr, expectedErr) } |
