diff options
| author | Aleksandar Pesic <peske.nis@gmail.com> | 2022-12-04 21:48:20 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-12-12 13:26:51 +0100 |
| commit | 3ac1d35257b56cca92ad0eb7f4d18abb366cf105 (patch) | |
| tree | 432db14d1f92a252518f371be420fa0d3ef044c8 /internal/oauth | |
| parent | 37bca013bd4405548b274ac473acf959ad661ee6 (diff) | |
simplify error handling
fixes #6
Signed-off-by: Aleksandar Pesic <peske.nis@gmail.com>
Diffstat (limited to 'internal/oauth')
| -rw-r--r-- | internal/oauth/oauth.go | 307 |
1 files changed, 105 insertions, 202 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 802295d..3dcd3d3 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -1,4 +1,4 @@ -// package oauth implement an oauth client defined in e.g. rfc 6749 +// Package oauth implement an oauth client defined in e.g. rfc 6749 // However, we try to follow some recommendations from the v2.1 oauth draft RFC // Some specific things we implement here: // - PKCE (RFC 7636) @@ -10,7 +10,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "html/template" "net" @@ -20,7 +19,7 @@ import ( httpw "github.com/eduvpn/eduvpn-common/internal/http" "github.com/eduvpn/eduvpn-common/internal/util" - "github.com/eduvpn/eduvpn-common/types" + "github.com/go-errors/errors" ) // genState generates a random base64 string to be used for state @@ -31,13 +30,13 @@ import ( // client. // We implement it similarly to the verifier. func genState() (string, error) { - randomBytes, err := util.MakeRandomByteSlice(32) + bts, err := util.MakeRandomByteSlice(32) if err != nil { - return "", types.NewWrappedError("failed generating an OAuth state", err) + return "", err } - // For consistency we also use raw url encoding here - return base64.RawURLEncoding.EncodeToString(randomBytes), nil + // For consistency, we also use raw url encoding here + return base64.RawURLEncoding.EncodeToString(bts), nil } // genChallengeS256 generates a sha256 base64 challenge from a verifier @@ -68,10 +67,7 @@ func genChallengeS256(verifier string) string { func genVerifier() (string, error) { randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { - return "", types.NewWrappedError( - "failed generating an OAuth verifier", - err, - ) + return "", err } return base64.RawURLEncoding.EncodeToString(randomBytes), nil @@ -89,10 +85,10 @@ type OAuth struct { TokenURL string `json:"token_url"` // session is the internal in progress OAuth session - session ExchangeSession `json:"-"` + session ExchangeSession // Token is where the access and refresh tokens are stored along with the timestamps - token Token `json:"-"` + token Token } // ExchangeSession is a structure that gets passed to the callback for easy access to the current state. @@ -126,39 +122,31 @@ type ExchangeSession struct { // It returns the access token as a string, possibly obtained fresh using the Refresh Token // If the token cannot be obtained, an error is returned and the token is an empty string. func (oauth *OAuth) AccessToken() (string, error) { - errorMessage := "failed getting access token" - tokens := oauth.token + ts := oauth.token // We have tokens... // The tokens are not expired yet - // So they should be valid, re-authorization not needed - if !tokens.Expired() { - return tokens.access, nil + // So they should be valid, re-login not needed + if !ts.Expired() { + return ts.access, nil } // Check if refresh is even possible by doing a simple check if the refresh token is empty // This is not needed but reduces API calls to the server - if tokens.refresh == "" { - return "", types.NewWrappedError( - errorMessage, - &TokensInvalidError{Cause: "no refresh token is present"}, - ) + if ts.refresh == "" { + return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) } // Otherwise refresh and then later return the access token if we are successful - refreshErr := oauth.tokensWithRefresh() - if refreshErr != nil { + err := oauth.tokensWithRefresh() + if err != nil { // We have failed to ensure the tokens due to refresh not working - return "", types.NewWrappedError( - errorMessage, - &TokensInvalidError{ - Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr), - }, - ) + return "", errors.Wrap( + &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0) } // We have obtained new tokens with refresh - return tokens.access, nil + return ts.access, nil } // setupListener sets up an OAuth listener @@ -166,24 +154,22 @@ func (oauth *OAuth) AccessToken() (string, error) { // @see https://www.ietf.org/archive/id/draft-ietf-oauth-v2-1-07.html#section-8.4.2 // "Loopback Interface Redirection". func (oauth *OAuth) setupListener() error { - errorMessage := "failed setting up listener" oauth.session.Context = context.Background() // create a listener - listener, listenerErr := net.Listen("tcp", "127.0.0.1:0") - if listenerErr != nil { - return types.NewWrappedError(errorMessage, listenerErr) + lst, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return errors.WrapPrefix(err, "net.Listen failed", 0) } - oauth.session.Listener = listener + oauth.session.Listener = lst return nil } // tokensWithCallback gets the OAuth tokens using a local web server // If it was unsuccessful it returns an error. func (oauth *OAuth) tokensWithCallback() error { - errorMessage := "failed getting tokens with callback" if oauth.session.Listener == nil { - return types.NewWrappedError(errorMessage, errors.New("no listener")) + return errors.Errorf("failed getting tokens with callback: no listener") } mux := http.NewServeMux() // server /callback over the listener address @@ -196,7 +182,7 @@ func (oauth *OAuth) tokensWithCallback() error { mux.HandleFunc("/callback", oauth.Callback) if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed { - return types.NewWrappedError(errorMessage, err) + return errors.WrapPrefix(err, "failed getting tokens with callback", 0) } return oauth.session.CallbackError } @@ -205,23 +191,18 @@ func (oauth *OAuth) tokensWithCallback() error { // It calculates the expired timestamp by having a 'startTime' passed to it // The URL that is input here is used for additional context. func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error { - responseStructure := TokenResponse{} - - jsonErr := json.Unmarshal(response, &responseStructure) - if jsonErr != nil { - return types.NewWrappedError( - "failed filling OAuth tokens", - &httpw.ParseJSONError{URL: url, Body: string(response), Err: jsonErr}, - ) - } - - internalStructure := Token{} - internalStructure.expiredTimestamp = startTime.Add( - time.Second * time.Duration(responseStructure.Expires), - ) - internalStructure.access = responseStructure.Access - internalStructure.refresh = responseStructure.Refresh - oauth.token = internalStructure + res := TokenResponse{} + + err := json.Unmarshal(response, &res) + if err != nil { + return errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0) + } + + oauth.token = Token{ + access: res.Access, + refresh: res.Refresh, + expiredTimestamp: startTime.Add(time.Second * time.Duration(res.Expires)), + } return nil } @@ -240,14 +221,13 @@ func (oauth *OAuth) SetTokenRenew() { // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. func (oauth *OAuth) tokensWithAuthCode(authCode string) error { - errorMessage := "failed getting tokens with the authorization code" // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code - reqURL := oauth.TokenURL + u := oauth.TokenURL - port, portErr := oauth.ListenerPort() - if portErr != nil { - return types.NewWrappedError(errorMessage, portErr) + port, err := oauth.ListenerPort() + if err != nil { + return err } data := url.Values{ @@ -257,21 +237,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { "grant_type": {"authorization_code"}, "redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)}, } - headers := http.Header{ + h := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &httpw.OptionalParams{Headers: headers, Body: data} - currentTime := time.Now() - _, body, bodyErr := httpw.PostWithOpts(reqURL, opts) - if bodyErr != nil { - return types.NewWrappedError(errorMessage, bodyErr) + opts := &httpw.OptionalParams{Headers: h, Body: data} + now := time.Now() + _, body, err := httpw.PostWithOpts(u, opts) + if err != nil { + return err } - fillErr := oauth.fillToken(body, currentTime, reqURL) - if fillErr != nil { - return types.NewWrappedError(errorMessage, fillErr) - } - return nil + return oauth.fillToken(body, now, u) } // tokensWithRefresh gets the access and refresh tokens with a previously received refresh token @@ -279,27 +255,22 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. func (oauth *OAuth) tokensWithRefresh() error { - errorMessage := "failed getting tokens with the refresh token" - reqURL := oauth.TokenURL + u := oauth.TokenURL data := url.Values{ "refresh_token": {oauth.token.refresh}, "grant_type": {"refresh_token"}, } - headers := http.Header{ + h := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &httpw.OptionalParams{Headers: headers, Body: data} - currentTime := time.Now() - _, body, bodyErr := httpw.PostWithOpts(reqURL, opts) - if bodyErr != nil { - return types.NewWrappedError(errorMessage, bodyErr) + opts := &httpw.OptionalParams{Headers: h, Body: data} + now := time.Now() + _, body, err := httpw.PostWithOpts(u, opts) + if err != nil { + return err } - fillErr := oauth.fillToken(body, currentTime, reqURL) - if fillErr != nil { - return types.NewWrappedError(errorMessage, fillErr) - } - return nil + return oauth.fillToken(body, now, u) } // responseTemplate is the HTML template for the OAuth authorized response @@ -349,27 +320,17 @@ type oauthResponseHTML struct { // writeResponseHTML writes the OAuth response using a response writer and the title + message // If it was unsuccessful it returns an error. func writeResponseHTML(w http.ResponseWriter, title string, message string) error { - errorMessage := "failed writing response HTML" - template, templateErr := template.New("oauth-response").Parse(responseTemplate) - if templateErr != nil { - return types.NewWrappedError(errorMessage, templateErr) + t, err := template.New("oauth-response").Parse(responseTemplate) + if err != nil { + return errors.WrapPrefix(err, "failed writing response HTML", 0) } - executeErr := template.Execute(w, oauthResponseHTML{ - Title: title, - Message: message, - }) - if executeErr != nil { - return types.NewWrappedError(errorMessage, executeErr) - } - return nil + return t.Execute(w, oauthResponseHTML{Title: title, Message: message}) } // Callback is the public function used to get the OAuth tokens using an authorization code callback // The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { - errorMessage := "failed callback to retrieve the authorization code" - // Shutdown after we're done defer func() { // writing the html is best effort @@ -383,64 +344,49 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.") } if oauth.session.Server != nil { - go oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck + go func() { + _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck + }() } }() // ISS: https://www.rfc-editor.org/rfc/rfc9207.html // TODO: Make this a required parameter in the future - urlQuery := req.URL.Query() - extractedISS := urlQuery.Get("iss") - if extractedISS != "" { - if oauth.session.ISS != extractedISS { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackISSMatchError{ISS: extractedISS, ExpectedISS: oauth.session.ISS}, - ) + q := req.URL.Query() + iss := q.Get("iss") + if iss != "" { + if oauth.session.ISS != iss { + oauth.session.CallbackError = errors.Errorf("failed matching ISS; expected '%s' got '%s'", + oauth.session.ISS, iss) return } } // Make sure the state is present and matches to protect against cross-site request forgeries // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 - extractedState := urlQuery.Get("state") - if extractedState == "" { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackParameterError{Parameter: "state", URL: req.URL.String()}, - ) + state := q.Get("state") + if state == "" { + oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'state' from '%s'", req.URL) return } // The state is the first entry - if extractedState != oauth.session.State { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackStateMatchError{ - State: extractedState, - ExpectedState: oauth.session.State, - }, - ) + if state != oauth.session.State { + oauth.session.CallbackError = errors.Errorf("failed matching state; expected '%s' got '%s'", + oauth.session.State, state) return } // No authorization code - extractedCode := urlQuery.Get("code") - if extractedCode == "" { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - &CallbackParameterError{Parameter: "code", URL: req.URL.String()}, - ) + code := q.Get("code") + if code == "" { + oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'code' from '%s'", req.URL) return } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - getTokensErr := oauth.tokensWithAuthCode(extractedCode) - if getTokensErr != nil { - oauth.session.CallbackError = types.NewWrappedError( - errorMessage, - getTokensErr, - ) + if err := oauth.tokensWithAuthCode(code); err != nil { + oauth.session.CallbackError = errors.WrapPrefix(err, "failed callback to retrieve the authorization code", 0) return } } @@ -457,94 +403,78 @@ func (oauth *OAuth) Init(iss string, baseAuthorizationURL string, tokenURL strin // ListenerPort gets the listener for the OAuth web server // It returns the port as an integer and an error if there is any. -func (oauth OAuth) ListenerPort() (int, error) { - errorMessage := "failed to get listener port" - +func (oauth *OAuth) ListenerPort() (int, error) { if oauth.session.Listener == nil { - return 0, types.NewWrappedError(errorMessage, errors.New("no OAuth listener")) + return 0, errors.Errorf("failed to get listener port") } return oauth.session.Listener.Addr().(*net.TCPAddr).Port, nil } // AuthURL gets the authorization url to start the OAuth procedure. func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (string, error) { - errorMessage := "failed starting OAuth exchange" - // Generate the verifier and challenge - verifier, verifierErr := genVerifier() - if verifierErr != nil { - return "", types.NewWrappedError(errorMessage, verifierErr) + v, err := genVerifier() + if err != nil { + return "", errors.WrapPrefix(err, "genVerifier error", 0) } - challenge := genChallengeS256(verifier) // Generate the state - state, stateErr := genState() - if stateErr != nil { - return "", types.NewWrappedError(errorMessage, stateErr) + state, err := genState() + if err != nil { + return "", errors.WrapPrefix(err, "genState error", 0) } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - oauthSession := ExchangeSession{ + oauth.session = ExchangeSession{ ClientID: name, ISS: oauth.ISS, State: state, - Verifier: verifier, + Verifier: v, } - oauth.session = oauthSession // set up the listener to get the redirect URI - listenerErr := oauth.setupListener() - if listenerErr != nil { - return "", types.NewWrappedError(errorMessage, stateErr) + if err = oauth.setupListener(); err != nil { + return "", errors.WrapPrefix(err, "oauth.setupListener error", 0) } // Get the listener port - port, portErr := oauth.ListenerPort() - if portErr != nil { - return "", types.NewWrappedError(errorMessage, portErr) + port, err := oauth.ListenerPort() + if err != nil { + return "", errors.WrapPrefix(err, "oauth.ListenerPort error", 0) } - parameters := map[string]string{ + params := map[string]string{ "client_id": name, "code_challenge_method": "S256", - "code_challenge": challenge, + "code_challenge": genChallengeS256(v), "response_type": "code", "scope": "config", "state": state, "redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port), } - authURL, urlErr := httpw.ConstructURL(oauth.BaseAuthorizationURL, parameters) + u, err := httpw.ConstructURL(oauth.BaseAuthorizationURL, params) - if urlErr != nil { - return "", types.NewWrappedError(errorMessage, urlErr) + if err != nil { + return "", errors.WrapPrefix(err, "httpw.ConstructURL error", 0) } // Return the url processed - return postProcessAuth(authURL), nil + return postProcessAuth(u), nil } // Exchange starts the OAuth exchange by getting the tokens with the redirect callback // If it was unsuccessful it returns an error. func (oauth *OAuth) Exchange() error { - tokenErr := oauth.tokensWithCallback() - - if tokenErr != nil { - return types.NewWrappedError("failed finishing OAuth", tokenErr) - } - return nil + return oauth.tokensWithCallback() } // Cancel cancels the existing OAuth // TODO: Use context for this. func (oauth *OAuth) Cancel() { - oauth.session.CallbackError = types.NewWrappedErrorLevel( - types.ErrInfo, - "cancelled OAuth", - &CancelledCallbackError{}, - ) + oauth.session.CallbackError = errors.Wrap(&CancelledCallbackError{}, 0) if oauth.session.Server != nil { - oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck + _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck } } @@ -554,33 +484,6 @@ func (e *CancelledCallbackError) Error() string { return "client cancelled OAuth" } -type CallbackParameterError struct { - Parameter string - URL string -} - -func (e *CallbackParameterError) Error() string { - return fmt.Sprintf("failed retrieving parameter: %s in url: %s", e.Parameter, e.URL) -} - -type CallbackStateMatchError struct { - State string - ExpectedState string -} - -func (e *CallbackStateMatchError) Error() string { - return fmt.Sprintf("failed matching state, got: %s, want: %s", e.State, e.ExpectedState) -} - -type CallbackISSMatchError struct { - ISS string - ExpectedISS string -} - -func (e *CallbackISSMatchError) Error() string { - return fmt.Sprintf("failed matching ISS, got: %s, want: %s", e.ISS, e.ExpectedISS) -} - type TokensInvalidError struct { Cause string } |
