summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-04-12 22:51:30 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2023-09-25 09:43:37 +0200
commit2898723cbe9c2bd65995dc22d080c3067ebdf4b7 (patch)
tree3dfa0a8c60973d2ccb78abd39422567fb1e19cdb
parent056e0c17a72cf9000c0ed1771c1ab38449c726fd (diff)
OAuth: Pass a context around
-rw-r--r--internal/oauth/oauth.go39
-rw-r--r--internal/oauth/oauth_test.go13
-rw-r--r--internal/oauth/token.go7
3 files changed, 31 insertions, 28 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index f2e7719..f9bf164 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -123,12 +123,12 @@ type exchangeSession struct {
// AccessToken gets the OAuth access token used for contacting the server API
// It returns the access token as a string, possibly obtained fresh using the Refresh Token
// If the token cannot be obtained, an error is returned and the token is an empty string.
-func (oauth *OAuth) AccessToken() (string, error) {
+func (oauth *OAuth) AccessToken(ctx context.Context) (string, error) {
tl := oauth.token
if tl == nil {
return "", errors.New("No token structure available")
}
- return tl.Access()
+ return tl.Access(ctx)
}
// setupListener sets up an OAuth listener
@@ -147,7 +147,7 @@ func (oauth *OAuth) setupListener() error {
// tokensWithCallback gets the OAuth tokens using a local web server
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithCallback() error {
+func (oauth *OAuth) tokensWithCallback(ctx context.Context) error {
if oauth.session.Listener == nil {
return errors.New("failed getting tokens with callback: no listener")
}
@@ -159,7 +159,7 @@ func (oauth *OAuth) tokensWithCallback() error {
// A bit overkill maybe for a local server but good to define anyways
ReadHeaderTimeout: 60 * time.Second,
}
- defer s.Shutdown(context.Background()) //nolint:errcheck
+ defer s.Shutdown(ctx) //nolint:errcheck
// Use a sync.Once to only handle one request up until we shutdown the server
var once sync.Once
@@ -174,7 +174,12 @@ func (oauth *OAuth) tokensWithCallback() error {
oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0)
}
}()
- return <-oauth.session.ErrChan
+ select {
+ case err := <-oauth.session.ErrChan:
+ return err
+ case <-ctx.Done():
+ return errors.WrapPrefix(context.Canceled, "stopped oauth server", 0)
+ }
}
// tokenResponse fills the OAuth token response structure by the response
@@ -221,7 +226,7 @@ func (oauth *OAuth) Token() Token {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
+func (oauth *OAuth) tokensWithAuthCode(ctx context.Context, authCode string) error {
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
u := oauth.TokenURL
@@ -245,7 +250,7 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
now := time.Now()
// We are sure that we have a http client because we have initialized it when starting the exchange
- _, body, err := oauth.httpClient.PostWithOpts(u, opts)
+ _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts)
if err != nil {
return err
}
@@ -274,7 +279,7 @@ func (oauth *OAuth) UpdateTokens(t Token) {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) {
+func (oauth *OAuth) refreshResponse(ctx context.Context, r string) (*TokenResponse, time.Time, error) {
u := oauth.TokenURL
if oauth.token == nil {
return nil, time.Time{}, errors.New("No oauth token structure in refresh")
@@ -298,7 +303,7 @@ func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error)
oauth.httpClient = httpw.NewClient()
}
- _, body, err := oauth.httpClient.PostWithOpts(u, opts)
+ _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts)
if err != nil {
return nil, time.Time{}, err
}
@@ -394,7 +399,7 @@ func (s *exchangeSession) Authcode(url *url.URL) (string, error) {
// tokenHandler gets the tokens using the authorization code that is obtained through the url
// This function is called by the http handler and returns an error if the tokens cannot be obtained
-func (oauth *OAuth) tokenHandler(url *url.URL) error {
+func (oauth *OAuth) tokenHandler(ctx context.Context, url *url.URL) error {
// Get the authorization code
c, err := oauth.session.Authcode(url)
if err != nil {
@@ -402,14 +407,15 @@ func (oauth *OAuth) tokenHandler(url *url.URL) error {
}
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- return oauth.tokensWithAuthCode(c)
+ return oauth.tokensWithAuthCode(ctx, c)
}
// Handler is the function used to get the OAuth tokens using an authorization code callback
// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
// It sends an error to the session channel (can be nil)
func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) {
- err := oauth.tokenHandler(req.URL)
+ // TODO: should this be something else than context background?
+ err := oauth.tokenHandler(context.Background(), req.URL)
if err != nil {
_ = writeResponseHTML(
w,
@@ -507,17 +513,12 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) Exchange() error {
+func (oauth *OAuth) Exchange(ctx context.Context) error {
// If there is no HTTP client defined, create a new one
if oauth.httpClient == nil {
oauth.httpClient = httpw.NewClient()
}
- return oauth.tokensWithCallback()
-}
-
-// Cancel cancels the existing OAuth server by sending a cancel error to the channel
-func (oauth *OAuth) Cancel() {
- oauth.session.ErrChan <- errors.Wrap(&CancelledCallbackError{}, 0)
+ return oauth.tokensWithCallback(ctx)
}
type CancelledCallbackError struct{}
diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go
index 8682e24..60ce5c7 100644
--- a/internal/oauth/oauth_test.go
+++ b/internal/oauth/oauth_test.go
@@ -1,6 +1,7 @@
package oauth
import (
+ "context"
"encoding/json"
"fmt"
"net/url"
@@ -60,7 +61,7 @@ func Test_challengergen(t *testing.T) {
func Test_accessToken(t *testing.T) {
o := OAuth{}
- _, err := o.AccessToken()
+ _, err := o.AccessToken(context.Background())
if err == nil {
t.Fatalf("No error when getting access token on empty structure")
}
@@ -69,7 +70,7 @@ func Test_accessToken(t *testing.T) {
want := "test"
expired := time.Now().Add(1 * time.Hour)
o = OAuth{token: &tokenLock{t: &tokenRefresher{Token: Token{Access: want, ExpiredTimestamp: expired}}}}
- got, err := o.AccessToken()
+ got, err := o.AccessToken(context.Background())
if err != nil {
t.Fatalf("Got error when getting access token on non-empty structure: %v", err)
}
@@ -80,8 +81,8 @@ func Test_accessToken(t *testing.T) {
// Set the tokens as expired
o.SetTokenExpired()
- // We should get an error because expired and no refresh token
- _, err = o.AccessToken()
+ // We should not get an error because expired and no refresh token
+ _, err = o.AccessToken(context.Background())
if err == nil {
t.Fatal("Got no error when getting access token on non-empty structure and expired")
}
@@ -90,7 +91,7 @@ func Test_accessToken(t *testing.T) {
// Now we internally update the refresh function and refresh token, we should get new tokens
refresh := "refresh"
o.token.t.Refresh = refresh
- o.token.t.Refresher = func(refreshToken string) (*TokenResponse, time.Time, error) {
+ o.token.t.Refresher = func(ctx context.Context, refreshToken string) (*TokenResponse, time.Time, error) {
if refreshToken != refresh {
t.Fatalf("Passed refresh token to refresher not equal to updated refresh token, got: %v, want: %v", refreshToken, refresh)
}
@@ -99,7 +100,7 @@ func Test_accessToken(t *testing.T) {
return r, expired, nil
}
- got, err = o.AccessToken()
+ got, err = o.AccessToken(context.Background())
if err != nil {
t.Fatalf("Got error when getting access token on non-empty expired structure and with a 'valid' refresh token: %v", err)
}
diff --git a/internal/oauth/token.go b/internal/oauth/token.go
index 58d6136..a3d0d3b 100644
--- a/internal/oauth/token.go
+++ b/internal/oauth/token.go
@@ -1,6 +1,7 @@
package oauth
import (
+ "context"
"fmt"
"sync"
"time"
@@ -42,7 +43,7 @@ type Token struct {
type tokenRefresher struct {
Token
// Refresher is the function that refreshes the token
- Refresher func(string) (*TokenResponse, time.Time, error)
+ Refresher func(context.Context, string) (*TokenResponse, time.Time, error)
}
// tokenLock is a wrapper around token that protects it with a lock
@@ -58,7 +59,7 @@ type tokenLock struct {
// Access gets the OAuth access token used for contacting the server API
// It returns the access token as a string, possibly obtained fresh using the refresher
// If the token cannot be obtained, an error is returned and the token is an empty string.
-func (l *tokenLock) Access() (string, error) {
+func (l *tokenLock) Access(ctx context.Context) (string, error) {
log.Logger.Debugf("Getting access token")
l.mu.Lock()
defer l.mu.Unlock()
@@ -78,7 +79,7 @@ func (l *tokenLock) Access() (string, error) {
}
// Otherwise refresh and then later return the access token if we are successful
- tr, s, err := l.t.Refresher(l.t.Refresh)
+ tr, s, err := l.t.Refresher(ctx, l.t.Refresh)
if err != nil {
log.Logger.Debugf("Got a refresh token error: %v", err)
// We have failed to ensure the tokens due to refresh not working