diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-10-19 16:51:48 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-10-19 17:05:59 +0200 |
| commit | 7260aa0cd70195a4679ca3c94204d9e618f947f2 (patch) | |
| tree | 9321f5f3d21b06d1ab6dd50420879bc5ea41f044 /internal/oauth | |
| parent | f1a265190d8fd862bfff680fd0937a7f99759955 (diff) | |
Refactor: Make errors use the parent's error level
- All wrapped errors have to be created with types.NewWrappedError to
inherit the error level from the parent
- Or types.NewWrappedErrorLevel can be used which means a custom error
level is given. For example this is done with cancelling OAuth
- Client public errors are forwarded with handleError that also logs
it with the error's level
Diffstat (limited to 'internal/oauth')
| -rw-r--r-- | internal/oauth/oauth.go | 121 |
1 files changed, 61 insertions, 60 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 6ac773c..df29a9c 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -28,7 +28,7 @@ import ( func genState() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err} + return "", types.NewWrappedError("failed generating an OAuth state", err) } // For consistency we also use raw url encoding here @@ -61,10 +61,10 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &types.WrappedErrorMessage{ - Message: "failed generating an OAuth verifier", - Err: err, - } + return "", types.NewWrappedError( + "failed generating an OAuth verifier", + err, + ) } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -111,7 +111,7 @@ func (oauth *OAuth) setupListener() error { // create a listener listener, listenerErr := net.Listen("tcp", ":0") if listenerErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: listenerErr} + return types.NewWrappedError(errorMessage, listenerErr) } oauth.Session.Listener = listener return nil @@ -120,7 +120,7 @@ func (oauth *OAuth) setupListener() error { func (oauth *OAuth) getTokensWithCallback() error { errorMessage := "failed getting tokens with callback" if oauth.Session.Listener == nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No listener")} + return types.NewWrappedError(errorMessage, errors.New("No listener")) } mux := http.NewServeMux() // server /callback over the listener address @@ -130,7 +130,7 @@ func (oauth *OAuth) getTokensWithCallback() error { mux.HandleFunc("/callback", oauth.Callback) if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed { - return &types.WrappedErrorMessage{Message: errorMessage, Err: err} + return types.NewWrappedError(errorMessage, err) } return oauth.Session.CallbackError } @@ -146,7 +146,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { port, portErr := oauth.GetListenerPort() if portErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: portErr} + return types.NewWrappedError(errorMessage, portErr) } data := url.Values{ @@ -163,7 +163,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { current_time := util.GetCurrentTime() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return types.NewWrappedError(errorMessage, bodyErr) } tokenStructure := OAuthToken{} @@ -171,10 +171,10 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, - } + return types.NewWrappedError( + errorMessage, + &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, + ) } tokenStructure.ExpiredTimestamp = current_time.Add( @@ -207,17 +207,17 @@ func (oauth *OAuth) getTokensWithRefresh() error { current_time := util.GetCurrentTime() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} + return types.NewWrappedError(errorMessage, bodyErr) } tokenStructure := OAuthToken{} jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, - } + return types.NewWrappedError( + errorMessage, + &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}, + ) } tokenStructure.ExpiredTimestamp = current_time.Add( @@ -273,7 +273,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro errorMessage := "failed writing response HTML" template, templateErr := template.New("oauth-response").Parse(responseTemplate) if templateErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: templateErr} + return types.NewWrappedError(errorMessage, templateErr) } executeErr := template.Execute(w, oauthResponseHTML{ @@ -281,7 +281,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro Message: message, }) if executeErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: executeErr} + return types.NewWrappedError(errorMessage, executeErr) } return nil } @@ -310,10 +310,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { extractedISS := urlQuery.Get("iss") if extractedISS != "" { if oauth.Session.ISS != extractedISS { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS}, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS}, + ) return } @@ -323,31 +323,31 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 extractedState := urlQuery.Get("state") if extractedState == "" { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}, + ) return } // The state is the first entry if extractedState != oauth.Session.State { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackStateMatchError{ + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackStateMatchError{ State: extractedState, ExpectedState: oauth.Session.State, }, - } + ) return } // No authorization code extractedCode := urlQuery.Get("code") if extractedCode == "" { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}, + ) return } @@ -355,10 +355,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Obtaining the access and refresh tokens getTokensErr := oauth.getTokensWithAuthCode(extractedCode) if getTokensErr != nil { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: errorMessage, - Err: getTokensErr, - } + oauth.Session.CallbackError = types.NewWrappedError( + errorMessage, + getTokensErr, + ) return } } @@ -372,7 +372,7 @@ func (oauth OAuth) GetListenerPort() (int, error) { errorMessage := "failed to get listener port" if oauth.Session.Listener == nil { - return 0, &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No OAuth listener")} + return 0, types.NewWrappedError(errorMessage, errors.New("No OAuth listener")) } return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil } @@ -384,14 +384,14 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr} + return "", types.NewWrappedError(errorMessage, verifierErr) } challenge := genChallengeS256(verifier) // Generate the state state, stateErr := genState() if stateErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + return "", types.NewWrappedError(errorMessage, stateErr) } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client @@ -401,13 +401,13 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str // set up the listener to get the redirect URI listenerErr := oauth.setupListener() if listenerErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + return "", types.NewWrappedError(errorMessage, stateErr) } // Get the listener port port, portErr := oauth.GetListenerPort() if portErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: portErr} + return "", types.NewWrappedError(errorMessage, portErr) } parameters := map[string]string{ @@ -423,7 +423,7 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} + return "", types.NewWrappedError(errorMessage, urlErr) } // Return the url processed @@ -435,16 +435,17 @@ func (oauth *OAuth) Exchange() error { tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr} + return types.NewWrappedError("failed finishing OAuth", tokenErr) } return nil } func (oauth *OAuth) Cancel() { - oauth.Session.CallbackError = &types.WrappedErrorMessage{ - Message: "cancelled OAuth", - Err: &OAuthCancelledCallbackError{}, - } + oauth.Session.CallbackError = types.NewWrappedErrorLevel( + types.ERR_INFO, + "cancelled OAuth", + &OAuthCancelledCallbackError{}, + ) if oauth.Session.Server != nil { oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck } @@ -454,10 +455,10 @@ func (oauth *OAuth) EnsureTokens() error { errorMessage := "failed ensuring OAuth tokens" // Access Token or Refresh Tokens empty, we can not ensure the tokens if oauth.Token.Access == "" && oauth.Token.Refresh == "" { - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthTokensInvalidError{Cause: "tokens are empty"}, - } + return types.NewWrappedError( + errorMessage, + &OAuthTokensInvalidError{Cause: "tokens are empty"}, + ) } // We have tokens... @@ -472,12 +473,12 @@ func (oauth *OAuth) EnsureTokens() error { // We have obtained new tokens with refresh if refreshErr != nil { // We have failed to ensure the tokens due to refresh not working - return &types.WrappedErrorMessage{ - Message: errorMessage, - Err: &OAuthTokensInvalidError{ + return types.NewWrappedError( + errorMessage, + &OAuthTokensInvalidError{ Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr), }, - } + ) } return nil |
