diff options
Diffstat (limited to 'internal/oauth/oauth.go')
| -rw-r--r-- | internal/oauth/oauth.go | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index f2e7719..f9bf164 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -123,12 +123,12 @@ type exchangeSession struct { // AccessToken gets the OAuth access token used for contacting the server API // It returns the access token as a string, possibly obtained fresh using the Refresh Token // If the token cannot be obtained, an error is returned and the token is an empty string. -func (oauth *OAuth) AccessToken() (string, error) { +func (oauth *OAuth) AccessToken(ctx context.Context) (string, error) { tl := oauth.token if tl == nil { return "", errors.New("No token structure available") } - return tl.Access() + return tl.Access(ctx) } // setupListener sets up an OAuth listener @@ -147,7 +147,7 @@ func (oauth *OAuth) setupListener() error { // tokensWithCallback gets the OAuth tokens using a local web server // If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithCallback() error { +func (oauth *OAuth) tokensWithCallback(ctx context.Context) error { if oauth.session.Listener == nil { return errors.New("failed getting tokens with callback: no listener") } @@ -159,7 +159,7 @@ func (oauth *OAuth) tokensWithCallback() error { // A bit overkill maybe for a local server but good to define anyways ReadHeaderTimeout: 60 * time.Second, } - defer s.Shutdown(context.Background()) //nolint:errcheck + defer s.Shutdown(ctx) //nolint:errcheck // Use a sync.Once to only handle one request up until we shutdown the server var once sync.Once @@ -174,7 +174,12 @@ func (oauth *OAuth) tokensWithCallback() error { oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0) } }() - return <-oauth.session.ErrChan + select { + case err := <-oauth.session.ErrChan: + return err + case <-ctx.Done(): + return errors.WrapPrefix(context.Canceled, "stopped oauth server", 0) + } } // tokenResponse fills the OAuth token response structure by the response @@ -221,7 +226,7 @@ func (oauth *OAuth) Token() Token { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithAuthCode(authCode string) error { +func (oauth *OAuth) tokensWithAuthCode(ctx context.Context, authCode string) error { // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code u := oauth.TokenURL @@ -245,7 +250,7 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { now := time.Now() // We are sure that we have a http client because we have initialized it when starting the exchange - _, body, err := oauth.httpClient.PostWithOpts(u, opts) + _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts) if err != nil { return err } @@ -274,7 +279,7 @@ func (oauth *OAuth) UpdateTokens(t Token) { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. -func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) { +func (oauth *OAuth) refreshResponse(ctx context.Context, r string) (*TokenResponse, time.Time, error) { u := oauth.TokenURL if oauth.token == nil { return nil, time.Time{}, errors.New("No oauth token structure in refresh") @@ -298,7 +303,7 @@ func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) oauth.httpClient = httpw.NewClient() } - _, body, err := oauth.httpClient.PostWithOpts(u, opts) + _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts) if err != nil { return nil, time.Time{}, err } @@ -394,7 +399,7 @@ func (s *exchangeSession) Authcode(url *url.URL) (string, error) { // tokenHandler gets the tokens using the authorization code that is obtained through the url // This function is called by the http handler and returns an error if the tokens cannot be obtained -func (oauth *OAuth) tokenHandler(url *url.URL) error { +func (oauth *OAuth) tokenHandler(ctx context.Context, url *url.URL) error { // Get the authorization code c, err := oauth.session.Authcode(url) if err != nil { @@ -402,14 +407,15 @@ func (oauth *OAuth) tokenHandler(url *url.URL) error { } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - return oauth.tokensWithAuthCode(c) + return oauth.tokensWithAuthCode(ctx, c) } // Handler is the function used to get the OAuth tokens using an authorization code callback // The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 // It sends an error to the session channel (can be nil) func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) { - err := oauth.tokenHandler(req.URL) + // TODO: should this be something else than context background? + err := oauth.tokenHandler(context.Background(), req.URL) if err != nil { _ = writeResponseHTML( w, @@ -507,17 +513,12 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s // Exchange starts the OAuth exchange by getting the tokens with the redirect callback // If it was unsuccessful it returns an error. -func (oauth *OAuth) Exchange() error { +func (oauth *OAuth) Exchange(ctx context.Context) error { // If there is no HTTP client defined, create a new one if oauth.httpClient == nil { oauth.httpClient = httpw.NewClient() } - return oauth.tokensWithCallback() -} - -// Cancel cancels the existing OAuth server by sending a cancel error to the channel -func (oauth *OAuth) Cancel() { - oauth.session.ErrChan <- errors.Wrap(&CancelledCallbackError{}, 0) + return oauth.tokensWithCallback(ctx) } type CancelledCallbackError struct{} |
