summaryrefslogtreecommitdiff
path: root/internal/oauth
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-10-19 16:51:48 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-10-19 17:05:59 +0200
commit7260aa0cd70195a4679ca3c94204d9e618f947f2 (patch)
tree9321f5f3d21b06d1ab6dd50420879bc5ea41f044 /internal/oauth
parentf1a265190d8fd862bfff680fd0937a7f99759955 (diff)
Refactor: Make errors use the parent's error level
- All wrapped errors have to be created with types.NewWrappedError to inherit the error level from the parent - Or types.NewWrappedErrorLevel can be used which means a custom error level is given. For example this is done with cancelling OAuth - Client public errors are forwarded with handleError that also logs it with the error's level
Diffstat (limited to 'internal/oauth')
-rw-r--r--internal/oauth/oauth.go121
1 files changed, 61 insertions, 60 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 6ac773c..df29a9c 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -28,7 +28,7 @@ import (
func genState() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err}
+ return "", types.NewWrappedError("failed generating an OAuth state", err)
}
// For consistency we also use raw url encoding here
@@ -61,10 +61,10 @@ func genChallengeS256(verifier string) string {
func genVerifier() (string, error) {
randomBytes, err := util.MakeRandomByteSlice(32)
if err != nil {
- return "", &types.WrappedErrorMessage{
- Message: "failed generating an OAuth verifier",
- Err: err,
- }
+ return "", types.NewWrappedError(
+ "failed generating an OAuth verifier",
+ err,
+ )
}
return base64.RawURLEncoding.EncodeToString(randomBytes), nil
@@ -111,7 +111,7 @@ func (oauth *OAuth) setupListener() error {
// create a listener
listener, listenerErr := net.Listen("tcp", ":0")
if listenerErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: listenerErr}
+ return types.NewWrappedError(errorMessage, listenerErr)
}
oauth.Session.Listener = listener
return nil
@@ -120,7 +120,7 @@ func (oauth *OAuth) setupListener() error {
func (oauth *OAuth) getTokensWithCallback() error {
errorMessage := "failed getting tokens with callback"
if oauth.Session.Listener == nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No listener")}
+ return types.NewWrappedError(errorMessage, errors.New("No listener"))
}
mux := http.NewServeMux()
// server /callback over the listener address
@@ -130,7 +130,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
mux.HandleFunc("/callback", oauth.Callback)
if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: err}
+ return types.NewWrappedError(errorMessage, err)
}
return oauth.Session.CallbackError
}
@@ -146,7 +146,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
port, portErr := oauth.GetListenerPort()
if portErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: portErr}
+ return types.NewWrappedError(errorMessage, portErr)
}
data := url.Values{
@@ -163,7 +163,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
current_time := util.GetCurrentTime()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return types.NewWrappedError(errorMessage, bodyErr)
}
tokenStructure := OAuthToken{}
@@ -171,10 +171,10 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
- }
+ return types.NewWrappedError(
+ errorMessage,
+ &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
+ )
}
tokenStructure.ExpiredTimestamp = current_time.Add(
@@ -207,17 +207,17 @@ func (oauth *OAuth) getTokensWithRefresh() error {
current_time := util.GetCurrentTime()
_, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts)
if bodyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr}
+ return types.NewWrappedError(errorMessage, bodyErr)
}
tokenStructure := OAuthToken{}
jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
- }
+ return types.NewWrappedError(
+ errorMessage,
+ &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr},
+ )
}
tokenStructure.ExpiredTimestamp = current_time.Add(
@@ -273,7 +273,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
errorMessage := "failed writing response HTML"
template, templateErr := template.New("oauth-response").Parse(responseTemplate)
if templateErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: templateErr}
+ return types.NewWrappedError(errorMessage, templateErr)
}
executeErr := template.Execute(w, oauthResponseHTML{
@@ -281,7 +281,7 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
Message: message,
})
if executeErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: executeErr}
+ return types.NewWrappedError(errorMessage, executeErr)
}
return nil
}
@@ -310,10 +310,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
extractedISS := urlQuery.Get("iss")
if extractedISS != "" {
if oauth.Session.ISS != extractedISS {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS},
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.Session.ISS},
+ )
return
}
@@ -323,31 +323,31 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
extractedState := urlQuery.Get("state")
if extractedState == "" {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()},
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()},
+ )
return
}
// The state is the first entry
if extractedState != oauth.Session.State {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackStateMatchError{
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackStateMatchError{
State: extractedState,
ExpectedState: oauth.Session.State,
},
- }
+ )
return
}
// No authorization code
extractedCode := urlQuery.Get("code")
if extractedCode == "" {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()},
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()},
+ )
return
}
@@ -355,10 +355,10 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
// Obtaining the access and refresh tokens
getTokensErr := oauth.getTokensWithAuthCode(extractedCode)
if getTokensErr != nil {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: getTokensErr,
- }
+ oauth.Session.CallbackError = types.NewWrappedError(
+ errorMessage,
+ getTokensErr,
+ )
return
}
}
@@ -372,7 +372,7 @@ func (oauth OAuth) GetListenerPort() (int, error) {
errorMessage := "failed to get listener port"
if oauth.Session.Listener == nil {
- return 0, &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No OAuth listener")}
+ return 0, types.NewWrappedError(errorMessage, errors.New("No OAuth listener"))
}
return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil
}
@@ -384,14 +384,14 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
if verifierErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr}
+ return "", types.NewWrappedError(errorMessage, verifierErr)
}
challenge := genChallengeS256(verifier)
// Generate the state
state, stateErr := genState()
if stateErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ return "", types.NewWrappedError(errorMessage, stateErr)
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
@@ -401,13 +401,13 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str
// set up the listener to get the redirect URI
listenerErr := oauth.setupListener()
if listenerErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ return "", types.NewWrappedError(errorMessage, stateErr)
}
// Get the listener port
port, portErr := oauth.GetListenerPort()
if portErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: portErr}
+ return "", types.NewWrappedError(errorMessage, portErr)
}
parameters := map[string]string{
@@ -423,7 +423,7 @@ func (oauth *OAuth) GetAuthURL(name string, iss string, postProcessAuth func(str
authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
if urlErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
+ return "", types.NewWrappedError(errorMessage, urlErr)
}
// Return the url processed
@@ -435,16 +435,17 @@ func (oauth *OAuth) Exchange() error {
tokenErr := oauth.getTokensWithCallback()
if tokenErr != nil {
- return &types.WrappedErrorMessage{Message: "failed finishing OAuth", Err: tokenErr}
+ return types.NewWrappedError("failed finishing OAuth", tokenErr)
}
return nil
}
func (oauth *OAuth) Cancel() {
- oauth.Session.CallbackError = &types.WrappedErrorMessage{
- Message: "cancelled OAuth",
- Err: &OAuthCancelledCallbackError{},
- }
+ oauth.Session.CallbackError = types.NewWrappedErrorLevel(
+ types.ERR_INFO,
+ "cancelled OAuth",
+ &OAuthCancelledCallbackError{},
+ )
if oauth.Session.Server != nil {
oauth.Session.Server.Shutdown(oauth.Session.Context) //nolint:errcheck
}
@@ -454,10 +455,10 @@ func (oauth *OAuth) EnsureTokens() error {
errorMessage := "failed ensuring OAuth tokens"
// Access Token or Refresh Tokens empty, we can not ensure the tokens
if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthTokensInvalidError{Cause: "tokens are empty"},
- }
+ return types.NewWrappedError(
+ errorMessage,
+ &OAuthTokensInvalidError{Cause: "tokens are empty"},
+ )
}
// We have tokens...
@@ -472,12 +473,12 @@ func (oauth *OAuth) EnsureTokens() error {
// We have obtained new tokens with refresh
if refreshErr != nil {
// We have failed to ensure the tokens due to refresh not working
- return &types.WrappedErrorMessage{
- Message: errorMessage,
- Err: &OAuthTokensInvalidError{
+ return types.NewWrappedError(
+ errorMessage,
+ &OAuthTokensInvalidError{
Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
},
- }
+ )
}
return nil