diff options
Diffstat (limited to 'internal/oauth/oauth.go')
| -rw-r--r-- | internal/oauth/oauth.go | 108 |
1 files changed, 49 insertions, 59 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 1dbdad4..3d71c67 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -93,9 +93,6 @@ type OAuth struct { // ExchangeSession is a structure that gets passed to the callback for easy access to the current state. type ExchangeSession struct { - // CallbackError indicates an error returned by the server - CallbackError error - // ClientID is the ID of the OAuth client ClientID string @@ -111,11 +108,11 @@ type ExchangeSession struct { // Context is the context used for cancellation Context context.Context - // Server is the server of the session - Server *http.Server - // Listener is the listener where the servers 'listens' on Listener net.Listener + + // ErrChan is used to send the error from the handler + ErrChan chan error } // AccessToken gets the OAuth access token used for contacting the server API @@ -173,18 +170,21 @@ func (oauth *OAuth) tokensWithCallback() error { } mux := http.NewServeMux() // server /callback over the listener address - oauth.session.Server = &http.Server{ + s := &http.Server{ Handler: mux, // Define a default 60 second header read timeout to protect against a Slowloris Attack // A bit overkill maybe for a local server but good to define anyways ReadHeaderTimeout: 60 * time.Second, } - mux.HandleFunc("/callback", oauth.Callback) + defer s.Shutdown(oauth.session.Context) + mux.HandleFunc("/callback", oauth.Handler) - if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed { - return errors.WrapPrefix(err, "failed getting tokens with callback", 0) - } - return oauth.session.CallbackError + go func() { + if err := s.Serve(oauth.session.Listener); err != http.ErrServerClosed { + oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0) + } + }() + return <-oauth.session.ErrChan } // fillToken fills the OAuth token structure by the response @@ -328,67 +328,60 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro return t.Execute(w, oauthResponseHTML{Title: title, Message: message}) } -// Callback is the public function used to get the OAuth tokens using an authorization code callback -// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 -func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { - // Shutdown after we're done - defer func() { - // writing the html is best effort - if oauth.session.CallbackError != nil { - _ = writeResponseHTML( - w, - "Authorization Failed", - "The authorization has failed. See the log file for more information.", - ) - } else { - _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.") - } - if oauth.session.Server != nil { - go func() { - _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck - }() - } - }() - +func (s *ExchangeSession) Authcode(url *url.URL) (string, error) { // ISS: https://www.rfc-editor.org/rfc/rfc9207.html // TODO: Make this a required parameter in the future - q := req.URL.Query() + q := url.Query() iss := q.Get("iss") - if iss != "" { - if oauth.session.ISS != iss { - oauth.session.CallbackError = errors.Errorf("failed matching ISS; expected '%s' got '%s'", - oauth.session.ISS, iss) - return - } + if iss != "" && s.ISS != iss { + return "", errors.Errorf("failed matching ISS; expected '%s' got '%s'", s.ISS, iss) } // Make sure the state is present and matches to protect against cross-site request forgeries // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state := q.Get("state") if state == "" { - oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'state' from '%s'", req.URL) - return + return "", errors.Errorf("failed retrieving parameter 'state' from '%s'", url) } // The state is the first entry - if state != oauth.session.State { - oauth.session.CallbackError = errors.Errorf("failed matching state; expected '%s' got '%s'", - oauth.session.State, state) - return + if state != s.State { + return "", errors.Errorf("failed matching state; expected '%s' got '%s'", s.State, state) } // No authorization code code := q.Get("code") if code == "" { - oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'code' from '%s'", req.URL) - return + return "", errors.Errorf("failed retrieving parameter 'code' from '%s'", url) } + return code, nil +} + +func (oauth *OAuth) TokenHandler(url *url.URL) error { + // Get the authorization code + c, err := oauth.session.Authcode(url) + if err != nil { + return err + } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - if err := oauth.tokensWithAuthCode(code); err != nil { - oauth.session.CallbackError = errors.WrapPrefix(err, "failed callback to retrieve the authorization code", 0) - return - } + return oauth.tokensWithAuthCode(c) +} + +// Callback is the public function used to get the OAuth tokens using an authorization code callback +// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 +func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) { + err := oauth.TokenHandler(req.URL) + if err != nil { + _ = writeResponseHTML( + w, + "Authorization Failed", + "The authorization has failed. See the log file for more information.", + ) + } else { + _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.") + } + oauth.session.ErrChan <- err } // Init initializes OAuth with the following parameters: @@ -430,6 +423,7 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s ISS: oauth.ISS, State: state, Verifier: v, + ErrChan: make(chan error), } // set up the listener to get the redirect URI @@ -468,13 +462,9 @@ func (oauth *OAuth) Exchange() error { return oauth.tokensWithCallback() } -// Cancel cancels the existing OAuth -// TODO: Use context for this. +// Cancel cancels the existing OAuth server by sending a cancel error to the channel func (oauth *OAuth) Cancel() { - oauth.session.CallbackError = errors.Wrap(&CancelledCallbackError{}, 0) - if oauth.session.Server != nil { - _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck - } + oauth.session.ErrChan <- errors.Wrap(&CancelledCallbackError{}, 0) } type CancelledCallbackError struct{} |
