summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/oauth/oauth.go108
1 files changed, 49 insertions, 59 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 1dbdad4..3d71c67 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -93,9 +93,6 @@ type OAuth struct {
// ExchangeSession is a structure that gets passed to the callback for easy access to the current state.
type ExchangeSession struct {
- // CallbackError indicates an error returned by the server
- CallbackError error
-
// ClientID is the ID of the OAuth client
ClientID string
@@ -111,11 +108,11 @@ type ExchangeSession struct {
// Context is the context used for cancellation
Context context.Context
- // Server is the server of the session
- Server *http.Server
-
// Listener is the listener where the servers 'listens' on
Listener net.Listener
+
+ // ErrChan is used to send the error from the handler
+ ErrChan chan error
}
// AccessToken gets the OAuth access token used for contacting the server API
@@ -173,18 +170,21 @@ func (oauth *OAuth) tokensWithCallback() error {
}
mux := http.NewServeMux()
// server /callback over the listener address
- oauth.session.Server = &http.Server{
+ s := &http.Server{
Handler: mux,
// Define a default 60 second header read timeout to protect against a Slowloris Attack
// A bit overkill maybe for a local server but good to define anyways
ReadHeaderTimeout: 60 * time.Second,
}
- mux.HandleFunc("/callback", oauth.Callback)
+ defer s.Shutdown(oauth.session.Context)
+ mux.HandleFunc("/callback", oauth.Handler)
- if err := oauth.session.Server.Serve(oauth.session.Listener); err != http.ErrServerClosed {
- return errors.WrapPrefix(err, "failed getting tokens with callback", 0)
- }
- return oauth.session.CallbackError
+ go func() {
+ if err := s.Serve(oauth.session.Listener); err != http.ErrServerClosed {
+ oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0)
+ }
+ }()
+ return <-oauth.session.ErrChan
}
// fillToken fills the OAuth token structure by the response
@@ -328,67 +328,60 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
return t.Execute(w, oauthResponseHTML{Title: title, Message: message})
}
-// Callback is the public function used to get the OAuth tokens using an authorization code callback
-// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
-func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
- // Shutdown after we're done
- defer func() {
- // writing the html is best effort
- if oauth.session.CallbackError != nil {
- _ = writeResponseHTML(
- w,
- "Authorization Failed",
- "The authorization has failed. See the log file for more information.",
- )
- } else {
- _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
- }
- if oauth.session.Server != nil {
- go func() {
- _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
- }()
- }
- }()
-
+func (s *ExchangeSession) Authcode(url *url.URL) (string, error) {
// ISS: https://www.rfc-editor.org/rfc/rfc9207.html
// TODO: Make this a required parameter in the future
- q := req.URL.Query()
+ q := url.Query()
iss := q.Get("iss")
- if iss != "" {
- if oauth.session.ISS != iss {
- oauth.session.CallbackError = errors.Errorf("failed matching ISS; expected '%s' got '%s'",
- oauth.session.ISS, iss)
- return
- }
+ if iss != "" && s.ISS != iss {
+ return "", errors.Errorf("failed matching ISS; expected '%s' got '%s'", s.ISS, iss)
}
// Make sure the state is present and matches to protect against cross-site request forgeries
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15
state := q.Get("state")
if state == "" {
- oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'state' from '%s'", req.URL)
- return
+ return "", errors.Errorf("failed retrieving parameter 'state' from '%s'", url)
}
// The state is the first entry
- if state != oauth.session.State {
- oauth.session.CallbackError = errors.Errorf("failed matching state; expected '%s' got '%s'",
- oauth.session.State, state)
- return
+ if state != s.State {
+ return "", errors.Errorf("failed matching state; expected '%s' got '%s'", s.State, state)
}
// No authorization code
code := q.Get("code")
if code == "" {
- oauth.session.CallbackError = errors.Errorf("failed retrieving parameter 'code' from '%s'", req.URL)
- return
+ return "", errors.Errorf("failed retrieving parameter 'code' from '%s'", url)
}
+ return code, nil
+}
+
+func (oauth *OAuth) TokenHandler(url *url.URL) error {
+ // Get the authorization code
+ c, err := oauth.session.Authcode(url)
+ if err != nil {
+ return err
+ }
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- if err := oauth.tokensWithAuthCode(code); err != nil {
- oauth.session.CallbackError = errors.WrapPrefix(err, "failed callback to retrieve the authorization code", 0)
- return
- }
+ return oauth.tokensWithAuthCode(c)
+}
+
+// Callback is the public function used to get the OAuth tokens using an authorization code callback
+// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
+func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) {
+ err := oauth.TokenHandler(req.URL)
+ if err != nil {
+ _ = writeResponseHTML(
+ w,
+ "Authorization Failed",
+ "The authorization has failed. See the log file for more information.",
+ )
+ } else {
+ _ = writeResponseHTML(w, "Authorized", "The client has been successfully authorized. You can close this browser window.")
+ }
+ oauth.session.ErrChan <- err
}
// Init initializes OAuth with the following parameters:
@@ -430,6 +423,7 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
ISS: oauth.ISS,
State: state,
Verifier: v,
+ ErrChan: make(chan error),
}
// set up the listener to get the redirect URI
@@ -468,13 +462,9 @@ func (oauth *OAuth) Exchange() error {
return oauth.tokensWithCallback()
}
-// Cancel cancels the existing OAuth
-// TODO: Use context for this.
+// Cancel cancels the existing OAuth server by sending a cancel error to the channel
func (oauth *OAuth) Cancel() {
- oauth.session.CallbackError = errors.Wrap(&CancelledCallbackError{}, 0)
- if oauth.session.Server != nil {
- _ = oauth.session.Server.Shutdown(oauth.session.Context) //nolint:errcheck
- }
+ oauth.session.ErrChan <- errors.Wrap(&CancelledCallbackError{}, 0)
}
type CancelledCallbackError struct{}