diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-06-20 15:20:18 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-06-20 15:20:18 +0200 |
| commit | 2252135fadb8c579ad27345e3203be755130e3cd (patch) | |
| tree | ed5a530e85b43736fc0bc28c927cfa8488f9199b /internal/oauth | |
| parent | 7af07c596166bf93b79a9d0816b1950dde360fb9 (diff) | |
Refactor: Errors to have one custom type that is to be wrapped
- For this an `internal/types` package is created with a custom error type
- This custom error type can give back the cause and traceback of an error
Diffstat (limited to 'internal/oauth')
| -rw-r--r-- | internal/oauth/oauth.go | 144 |
1 files changed, 38 insertions, 106 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index f6ed916..824db90 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -8,10 +8,12 @@ import ( "fmt" "net/http" "net/url" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" httpw "github.com/jwijenbergh/eduvpn-common/internal/http" - "github.com/jwijenbergh/eduvpn-common/internal/util" "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/types" + "github.com/jwijenbergh/eduvpn-common/internal/util" ) // Generates a random base64 string to be used for state @@ -23,7 +25,7 @@ import ( func genState() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenStateError{Err: err} + return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth state", Err: err} } // For consistency we also use raw url encoding here @@ -49,7 +51,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", &OAuthGenVerifierError{Err: err} + return "", &types.WrappedErrorMessage{Message: "failed generating an OAuth verifier", Err: err} } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -60,8 +62,8 @@ type OAuth struct { Token OAuthToken `json:"token"` BaseAuthorizationURL string `json:"base_authorization_url"` TokenURL string `json:"token_url"` - Logger *log.FileLogger `json:"-"` - FSM *fsm.FSM `json:"-"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` } // This structure gets passed to the callback for easy access to the current state @@ -99,7 +101,7 @@ func (oauth *OAuth) getTokensWithCallback() error { } mux.HandleFunc("/callback", oauth.Callback) if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { - return &OAuthCallbackError{Addr: addr, Err: err} + return &types.WrappedErrorMessage{Message: "failed getting tokens with callback", Err: err} } return oauth.Session.CallbackError } @@ -108,9 +110,9 @@ func (oauth *OAuth) getTokensWithCallback() error { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { + errorMessage := "failed getting tokens with the authorization code" // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code - reqURL := oauth.TokenURL data := url.Values{ "client_id": {oauth.Session.ClientID}, @@ -126,7 +128,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { current_time := util.GenerateTimeSeconds() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &OAuthAuthError{Err: bodyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } tokenStructure := OAuthToken{} @@ -134,7 +136,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}} } tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires @@ -152,6 +154,7 @@ func (oauth *OAuth) isTokensExpired() bool { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 func (oauth *OAuth) getTokensWithRefresh() error { + errorMessage := "failed getting tokens with the refresh token" reqURL := oauth.TokenURL data := url.Values{ "refresh_token": {oauth.Token.Refresh}, @@ -164,14 +167,14 @@ func (oauth *OAuth) getTokensWithRefresh() error { current_time := util.GenerateTimeSeconds() _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { - return &OAuthRefreshError{Err: bodyErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: bodyErr} } tokenStructure := OAuthToken{} jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}} } tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires @@ -182,11 +185,15 @@ func (oauth *OAuth) getTokensWithRefresh() error { // //// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { + errorMessage := "failed callback to retrieve the authorization code" // Extract the authorization code code, success := req.URL.Query()["code"] - if !success { - oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()} + // Shutdown after we're done + defer func() { go oauth.Session.Server.Shutdown(oauth.Session.Context) + }() + if !success { + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "code", URL: req.URL.String()}} return } // The code is the first entry @@ -195,30 +202,25 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Make sure the state is present and matches to protect against cross-site request forgeries // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state, success := req.URL.Query()["state"] + if !success { - oauth.Session.CallbackError = &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()} - go oauth.Session.Server.Shutdown(oauth.Session.Context) + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackParameterError{Parameter: "state", URL: req.URL.String()}} return } // The state is the first entry extractedState := state[0] if extractedState != oauth.Session.State { - oauth.Session.CallbackError = &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} - go oauth.Session.Server.Shutdown(oauth.Session.Context) + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: &OAuthCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State}} return } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - err := oauth.getTokensWithAuthCode(extractedCode) - if err != nil { - oauth.Session.CallbackError = &OAuthCallbackGetTokensError{Err: err} - go oauth.Session.Server.Shutdown(oauth.Session.Context) + getTokensErr := oauth.getTokensWithAuthCode(extractedCode) + if getTokensErr != nil { + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: errorMessage, Err: getTokensErr} return } - - // Shutdown the server as we're done listening - go oauth.Session.Server.Shutdown(oauth.Session.Context) } func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) { @@ -235,19 +237,20 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm. // Starts the OAuth exchange for eduvpn. func (oauth *OAuth) start(name string) error { + errorMessage := "failed starting OAuth exchange" if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { - return &fsm.FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} } // Generate the state state, stateErr := genState() if stateErr != nil { - return &OAuthInitializeError{Err: stateErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} } // Generate the verifier and challenge verifier, verifierErr := genVerifier() if verifierErr != nil { - return &OAuthInitializeError{Err: verifierErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: verifierErr} } challenge := genChallengeS256(verifier) @@ -264,7 +267,7 @@ func (oauth *OAuth) start(name string) error { authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { - return &OAuthInitializeError{Err: urlErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client @@ -277,34 +280,36 @@ func (oauth *OAuth) start(name string) error { // Error definitions func (oauth *OAuth) Finish() error { + errorMessage := "failed finishing OAuth" if !oauth.FSM.HasTransition(fsm.AUTHORIZED) { - return &fsm.FSMWrongStateError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED} + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED}.CustomError()} } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { - return &OAuthFinishError{Err: tokenErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: tokenErr} } oauth.FSM.GoTransition(fsm.AUTHORIZED) return nil } func (oauth *OAuth) Cancel() { - oauth.Session.CallbackError = &OAuthCancelledCallbackError{} + oauth.Session.CallbackError = &types.WrappedErrorMessage{Message: "failed cancelling OAuth", Err: &OAuthCancelledCallbackError{}} oauth.Session.Server.Shutdown(oauth.Session.Context) } func (oauth *OAuth) Login(name string) error { + errorMessage := "failed OAuth login" authInitializeErr := oauth.start(name) if authInitializeErr != nil { - return &OAuthLoginError{Err: authInitializeErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} } oauthErr := oauth.Finish() if oauthErr != nil { - return &OAuthLoginError{Err: oauthErr} + return &types.WrappedErrorMessage{Message: errorMessage, Err: oauthErr} } return nil } @@ -342,31 +347,6 @@ func (e *OAuthCancelledCallbackError) Error() string { return fmt.Sprintf("client cancelled OAuth") } -type OAuthGenStateError struct { - Err error -} - -func (e *OAuthGenStateError) Error() string { - return fmt.Sprintf("failed generating state with error: %v", e.Err) -} - -type OAuthGenVerifierError struct { - Err error -} - -func (e *OAuthGenVerifierError) Error() string { - return fmt.Sprintf("failed generating verifier with error: %v", e.Err) -} - -type OAuthCallbackError struct { - Addr string - Err error -} - -func (e *OAuthCallbackError) Error() string { - return fmt.Sprintf("failed callback: %s with error: %v", e.Addr, e.Err) -} - type OAuthCallbackParameterError struct { Parameter string URL string @@ -384,51 +364,3 @@ type OAuthCallbackStateMatchError struct { func (e *OAuthCallbackStateMatchError) Error() string { return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) } - -type OAuthCallbackGetTokensError struct { - Err error -} - -func (e *OAuthCallbackGetTokensError) Error() string { - return fmt.Sprintf("failed getting tokens with error: %v", e.Err) -} - -type OAuthFinishError struct { - Err error -} - -func (e *OAuthFinishError) Error() string { - return fmt.Sprintf("failed finishing OAuth with error: %v", e.Err) -} - -type OAuthLoginError struct { - Err error -} - -func (e *OAuthLoginError) Error() string { - return fmt.Sprintf("failed OAuth logging in with error: %v", e.Err) -} - -type OAuthInitializeError struct { - Err error -} - -func (e *OAuthInitializeError) Error() string { - return fmt.Sprintf("failed initializing OAuth with error: %v", e.Err) -} - -type OAuthAuthError struct { - Err error -} - -func (e *OAuthAuthError) Error() string { - return fmt.Sprintf("failed getting tokens with auth code for OAuth with error: %v", e.Err) -} - -type OAuthRefreshError struct { - Err error -} - -func (e *OAuthRefreshError) Error() string { - return fmt.Sprintf("failed refreshing tokens for OAuth with error: %v", e.Err) -} |
