diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2023-04-12 22:51:30 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2023-09-25 09:43:37 +0200 |
| commit | 2898723cbe9c2bd65995dc22d080c3067ebdf4b7 (patch) | |
| tree | 3dfa0a8c60973d2ccb78abd39422567fb1e19cdb | |
| parent | 056e0c17a72cf9000c0ed1771c1ab38449c726fd (diff) | |
OAuth: Pass a context around
| -rw-r--r-- | internal/oauth/oauth.go | 39 | ||||
| -rw-r--r-- | internal/oauth/oauth_test.go | 13 | ||||
| -rw-r--r-- | internal/oauth/token.go | 7 |
3 files changed, 31 insertions, 28 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index f2e7719..f9bf164 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -123,12 +123,12 @@ type exchangeSession struct { // AccessToken gets the OAuth access token used for contacting the server API // It returns the access token as a string, possibly obtained fresh using the Refresh Token // If the token cannot be obtained, an error is returned and the token is an empty string. -func (oauth *OAuth) AccessToken() (string, error) { +func (oauth *OAuth) AccessToken(ctx context.Context) (string, error) { tl := oauth.token if tl == nil { return "", errors.New("No token structure available") } - return tl.Access() + return tl.Access(ctx) } // setupListener sets up an OAuth listener @@ -147,7 +147,7 @@ func (oauth *OAuth) setupListener() error { // tokensWithCallback gets the OAuth tokens using a local web server // If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithCallback() error { +func (oauth *OAuth) tokensWithCallback(ctx context.Context) error { if oauth.session.Listener == nil { return errors.New("failed getting tokens with callback: no listener") } @@ -159,7 +159,7 @@ func (oauth *OAuth) tokensWithCallback() error { // A bit overkill maybe for a local server but good to define anyways ReadHeaderTimeout: 60 * time.Second, } - defer s.Shutdown(context.Background()) //nolint:errcheck + defer s.Shutdown(ctx) //nolint:errcheck // Use a sync.Once to only handle one request up until we shutdown the server var once sync.Once @@ -174,7 +174,12 @@ func (oauth *OAuth) tokensWithCallback() error { oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0) } }() - return <-oauth.session.ErrChan + select { + case err := <-oauth.session.ErrChan: + return err + case <-ctx.Done(): + return errors.WrapPrefix(context.Canceled, "stopped oauth server", 0) + } } // tokenResponse fills the OAuth token response structure by the response @@ -221,7 +226,7 @@ func (oauth *OAuth) Token() Token { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithAuthCode(authCode string) error { +func (oauth *OAuth) tokensWithAuthCode(ctx context.Context, authCode string) error { // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code u := oauth.TokenURL @@ -245,7 +250,7 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { now := time.Now() // We are sure that we have a http client because we have initialized it when starting the exchange - _, body, err := oauth.httpClient.PostWithOpts(u, opts) + _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts) if err != nil { return err } @@ -274,7 +279,7 @@ func (oauth *OAuth) UpdateTokens(t Token) { // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. -func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) { +func (oauth *OAuth) refreshResponse(ctx context.Context, r string) (*TokenResponse, time.Time, error) { u := oauth.TokenURL if oauth.token == nil { return nil, time.Time{}, errors.New("No oauth token structure in refresh") @@ -298,7 +303,7 @@ func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) oauth.httpClient = httpw.NewClient() } - _, body, err := oauth.httpClient.PostWithOpts(u, opts) + _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts) if err != nil { return nil, time.Time{}, err } @@ -394,7 +399,7 @@ func (s *exchangeSession) Authcode(url *url.URL) (string, error) { // tokenHandler gets the tokens using the authorization code that is obtained through the url // This function is called by the http handler and returns an error if the tokens cannot be obtained -func (oauth *OAuth) tokenHandler(url *url.URL) error { +func (oauth *OAuth) tokenHandler(ctx context.Context, url *url.URL) error { // Get the authorization code c, err := oauth.session.Authcode(url) if err != nil { @@ -402,14 +407,15 @@ func (oauth *OAuth) tokenHandler(url *url.URL) error { } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - return oauth.tokensWithAuthCode(c) + return oauth.tokensWithAuthCode(ctx, c) } // Handler is the function used to get the OAuth tokens using an authorization code callback // The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 // It sends an error to the session channel (can be nil) func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) { - err := oauth.tokenHandler(req.URL) + // TODO: should this be something else than context background? + err := oauth.tokenHandler(context.Background(), req.URL) if err != nil { _ = writeResponseHTML( w, @@ -507,17 +513,12 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s // Exchange starts the OAuth exchange by getting the tokens with the redirect callback // If it was unsuccessful it returns an error. -func (oauth *OAuth) Exchange() error { +func (oauth *OAuth) Exchange(ctx context.Context) error { // If there is no HTTP client defined, create a new one if oauth.httpClient == nil { oauth.httpClient = httpw.NewClient() } - return oauth.tokensWithCallback() -} - -// Cancel cancels the existing OAuth server by sending a cancel error to the channel -func (oauth *OAuth) Cancel() { - oauth.session.ErrChan <- errors.Wrap(&CancelledCallbackError{}, 0) + return oauth.tokensWithCallback(ctx) } type CancelledCallbackError struct{} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go index 8682e24..60ce5c7 100644 --- a/internal/oauth/oauth_test.go +++ b/internal/oauth/oauth_test.go @@ -1,6 +1,7 @@ package oauth import ( + "context" "encoding/json" "fmt" "net/url" @@ -60,7 +61,7 @@ func Test_challengergen(t *testing.T) { func Test_accessToken(t *testing.T) { o := OAuth{} - _, err := o.AccessToken() + _, err := o.AccessToken(context.Background()) if err == nil { t.Fatalf("No error when getting access token on empty structure") } @@ -69,7 +70,7 @@ func Test_accessToken(t *testing.T) { want := "test" expired := time.Now().Add(1 * time.Hour) o = OAuth{token: &tokenLock{t: &tokenRefresher{Token: Token{Access: want, ExpiredTimestamp: expired}}}} - got, err := o.AccessToken() + got, err := o.AccessToken(context.Background()) if err != nil { t.Fatalf("Got error when getting access token on non-empty structure: %v", err) } @@ -80,8 +81,8 @@ func Test_accessToken(t *testing.T) { // Set the tokens as expired o.SetTokenExpired() - // We should get an error because expired and no refresh token - _, err = o.AccessToken() + // We should not get an error because expired and no refresh token + _, err = o.AccessToken(context.Background()) if err == nil { t.Fatal("Got no error when getting access token on non-empty structure and expired") } @@ -90,7 +91,7 @@ func Test_accessToken(t *testing.T) { // Now we internally update the refresh function and refresh token, we should get new tokens refresh := "refresh" o.token.t.Refresh = refresh - o.token.t.Refresher = func(refreshToken string) (*TokenResponse, time.Time, error) { + o.token.t.Refresher = func(ctx context.Context, refreshToken string) (*TokenResponse, time.Time, error) { if refreshToken != refresh { t.Fatalf("Passed refresh token to refresher not equal to updated refresh token, got: %v, want: %v", refreshToken, refresh) } @@ -99,7 +100,7 @@ func Test_accessToken(t *testing.T) { return r, expired, nil } - got, err = o.AccessToken() + got, err = o.AccessToken(context.Background()) if err != nil { t.Fatalf("Got error when getting access token on non-empty expired structure and with a 'valid' refresh token: %v", err) } diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 58d6136..a3d0d3b 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -1,6 +1,7 @@ package oauth import ( + "context" "fmt" "sync" "time" @@ -42,7 +43,7 @@ type Token struct { type tokenRefresher struct { Token // Refresher is the function that refreshes the token - Refresher func(string) (*TokenResponse, time.Time, error) + Refresher func(context.Context, string) (*TokenResponse, time.Time, error) } // tokenLock is a wrapper around token that protects it with a lock @@ -58,7 +59,7 @@ type tokenLock struct { // Access gets the OAuth access token used for contacting the server API // It returns the access token as a string, possibly obtained fresh using the refresher // If the token cannot be obtained, an error is returned and the token is an empty string. -func (l *tokenLock) Access() (string, error) { +func (l *tokenLock) Access(ctx context.Context) (string, error) { log.Logger.Debugf("Getting access token") l.mu.Lock() defer l.mu.Unlock() @@ -78,7 +79,7 @@ func (l *tokenLock) Access() (string, error) { } // Otherwise refresh and then later return the access token if we are successful - tr, s, err := l.t.Refresher(l.t.Refresh) + tr, s, err := l.t.Refresher(ctx, l.t.Refresh) if err != nil { log.Logger.Debugf("Got a refresh token error: %v", err) // We have failed to ensure the tokens due to refresh not working |
