diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-02 14:34:35 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-05-02 14:34:35 +0200 |
| commit | 466450f0c47bdc614e66326d90e5fc6fb56ae732 (patch) | |
| tree | a01518a58d50d2f8449d37dadecc40e35c9f1fe1 /internal/oauth.go | |
| parent | a2a8efdcaad3d9b1852b1367a7cd7e8c5860cecf (diff) | |
Refactor: Wrap most errors in a custom type
Diffstat (limited to 'internal/oauth.go')
| -rw-r--r-- | internal/oauth.go | 113 |
1 files changed, 72 insertions, 41 deletions
diff --git a/internal/oauth.go b/internal/oauth.go index 98af5a4..c13ea99 100644 --- a/internal/oauth.go +++ b/internal/oauth.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -20,7 +19,7 @@ import ( func genState() (string, error) { randomBytes, err := MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenStateUnableError{Err: err} + return "", &OAuthGenStateError{Err: err} } // For consistency we also use raw url encoding here @@ -46,7 +45,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenVerifierUnableError{Err: err} + return "", &OAuthGenVerifierError{Err: err} } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -95,7 +94,7 @@ func (oauth *OAuth) getTokensWithCallback() error { } mux.HandleFunc("/callback", oauth.Callback) if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { - return &OAuthFailedCallbackError{Addr: addr, Err: err} + return &OAuthCallbackError{Addr: addr, Err: err} } return oauth.Session.CallbackError } @@ -122,7 +121,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { current_time := GenerateTimeSeconds() _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return bodyErr + return &OAuthAuthError{Err: bodyErr} } tokenStructure := OAuthToken{} @@ -160,7 +159,7 @@ func (oauth *OAuth) getTokensWithRefresh() error { current_time := GenerateTimeSeconds() _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return bodyErr + return &OAuthRefreshError{Err: bodyErr} } tokenStructure := OAuthToken{} @@ -181,7 +180,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Extract the authorization code code, success := req.URL.Query()["code"] if !success { - oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()} + oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } @@ -192,14 +191,14 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state, success := req.URL.Query()["state"] if !success { - oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()} + oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } // The state is the first entry extractedState := state[0] if extractedState != oauth.Session.State { - oauth.Session.CallbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} + oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } @@ -208,7 +207,7 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Obtaining the access and refresh tokens err := oauth.getTokensWithAuthCode(extractedCode) if err != nil { - oauth.Session.CallbackError = &OAuthFailedCallbackGetTokensError{Err: err} + oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err} go oauth.Session.Server.Shutdown(oauth.Session.Context) return } @@ -225,18 +224,18 @@ func (oauth *OAuth) Init(fsm *FSM, logger *FileLogger) { // Starts the OAuth exchange for eduvpn. func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) error { if !oauth.FSM.HasTransition(OAUTH_STARTED) { - return errors.New(fmt.Sprintf("Failed starting oauth, invalid state %s", oauth.FSM.Current.String())) + return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED} } // Generate the state state, stateErr := genState() if stateErr != nil { - return &OAuthFailedInitializeError{Err: stateErr} + return &OAuthInitializeError{Err: stateErr} } // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return &OAuthFailedInitializeError{Err: verifierErr} + return &OAuthInitializeError{Err: verifierErr} } challenge := genChallengeS256(verifier) @@ -252,8 +251,8 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) authURL, urlErr := HTTPConstructURL(authorizationURL, parameters) - if urlErr != nil { // shouldn't happen - panic(urlErr) + if urlErr != nil { + return &OAuthInitializeError{Err: urlErr} } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client @@ -268,12 +267,12 @@ func (oauth *OAuth) start(name string, authorizationURL string, tokenURL string) // Error definitions func (oauth *OAuth) Finish() error { if !oauth.FSM.HasTransition(AUTHORIZED) { - return errors.New("invalid state to finish oauth") + return &FSMWrongStateError{Got: oauth.FSM.Current, Want: AUTHORIZED} } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return tokenErr + return &OAuthFinishError{Err: tokenErr} } oauth.FSM.GoTransition(AUTHORIZED) return nil @@ -288,13 +287,13 @@ func (oauth *OAuth) Login(name string, authorizationURL string, tokenURL string) authInitializeErr := oauth.start(name, authorizationURL, tokenURL) if authInitializeErr != nil { - return authInitializeErr + return &OAuthLoginError{Err: authInitializeErr} } oauthErr := oauth.Finish() if oauthErr != nil { - return oauthErr + return &OAuthLoginError{Err: oauthErr} } return nil } @@ -329,64 +328,96 @@ func (oauth *OAuth) NeedsRelogin() bool { type OAuthCancelledCallbackError struct{} func (e *OAuthCancelledCallbackError) Error() string { - return fmt.Sprintf("Client cancelled OAuth") + return fmt.Sprintf("client cancelled OAuth") } -type OAuthGenStateUnableError struct { +type OAuthGenStateError struct { Err error } -func (e *OAuthGenStateUnableError) Error() string { - return fmt.Sprintf("failed generating state with error %v", e.Err) +func (e *OAuthGenStateError) Error() string { + return fmt.Sprintf("failed generating state with error: %v", e.Err) } -type OAuthGenVerifierUnableError struct { +type OAuthGenVerifierError struct { Err error } -func (e *OAuthGenVerifierUnableError) Error() string { - return fmt.Sprintf("failed generating verifier with error %v", e.Err) +func (e *OAuthGenVerifierError) Error() string { + return fmt.Sprintf("failed generating verifier with error: %v", e.Err) } -type OAuthFailedCallbackError struct { +type OAuthCallbackError struct { Addr string Err error } -func (e *OAuthFailedCallbackError) Error() string { - return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err) +func (e *OAuthCallbackError) Error() string { + return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err) } -type OAuthFailedCallbackParameterError struct { +type OAuthCallbackParameterError struct { Parameter string URL string } -func (e *OAuthFailedCallbackParameterError) Error() string { - return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL) +func (e *OAuthCallbackParameterError) Error() string { + return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL) } -type OAuthFailedCallbackStateMatchError struct { +type OAuthCallbackStateMatchError struct { State string ExpectedState string } -func (e *OAuthFailedCallbackStateMatchError) Error() string { - return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState) +func (e *OAuthCallbackStateMatchError) Error() string { + return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } -type OAuthFailedCallbackGetTokensError struct { +type OAuthCallbackGetTokensError struct { Err error } -func (e *OAuthFailedCallbackGetTokensError) Error() string { - return fmt.Sprintf("failed getting tokens with error %v", e.Err) +func (e *OAuthCallbackGetTokensError) Error() string { + return fmt.Sprintf("failed getting tokens with error: %v", e.Err) } -type OAuthFailedInitializeError struct { +type OAuthFinishError struct { Err error } -func (e *OAuthFailedInitializeError) Error() string { - return fmt.Sprintf("failed initializing OAuth with error %v", e.Err) +func (e *OAuthFinishError) Error() string { + return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err) +} + +type OAuthLoginError struct { + Err error +} + +func (e *OAuthLoginError) Error() string { + return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err) +} + +type OAuthInitializeError struct { + Err error +} + +func (e *OAuthInitializeError) Error() string { + return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err) +} + +type OAuthAuthError struct { + Err error +} + +func (e *OAuthAuthError) Error() string { + return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err) +} + +type OAuthRefreshError struct { + Err error +} + +func (e *OAuthRefreshError) Error() string { + return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err) } |
