summaryrefslogtreecommitdiff
path: root/internal/oauth/oauth.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/oauth/oauth.go')
-rw-r--r--internal/oauth/oauth.go39
1 files changed, 20 insertions, 19 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index f2e7719..f9bf164 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -123,12 +123,12 @@ type exchangeSession struct {
// AccessToken gets the OAuth access token used for contacting the server API
// It returns the access token as a string, possibly obtained fresh using the Refresh Token
// If the token cannot be obtained, an error is returned and the token is an empty string.
-func (oauth *OAuth) AccessToken() (string, error) {
+func (oauth *OAuth) AccessToken(ctx context.Context) (string, error) {
tl := oauth.token
if tl == nil {
return "", errors.New("No token structure available")
}
- return tl.Access()
+ return tl.Access(ctx)
}
// setupListener sets up an OAuth listener
@@ -147,7 +147,7 @@ func (oauth *OAuth) setupListener() error {
// tokensWithCallback gets the OAuth tokens using a local web server
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithCallback() error {
+func (oauth *OAuth) tokensWithCallback(ctx context.Context) error {
if oauth.session.Listener == nil {
return errors.New("failed getting tokens with callback: no listener")
}
@@ -159,7 +159,7 @@ func (oauth *OAuth) tokensWithCallback() error {
// A bit overkill maybe for a local server but good to define anyways
ReadHeaderTimeout: 60 * time.Second,
}
- defer s.Shutdown(context.Background()) //nolint:errcheck
+ defer s.Shutdown(ctx) //nolint:errcheck
// Use a sync.Once to only handle one request up until we shutdown the server
var once sync.Once
@@ -174,7 +174,12 @@ func (oauth *OAuth) tokensWithCallback() error {
oauth.session.ErrChan <- errors.WrapPrefix(err, "failed getting tokens with callback", 0)
}
}()
- return <-oauth.session.ErrChan
+ select {
+ case err := <-oauth.session.ErrChan:
+ return err
+ case <-ctx.Done():
+ return errors.WrapPrefix(context.Canceled, "stopped oauth server", 0)
+ }
}
// tokenResponse fills the OAuth token response structure by the response
@@ -221,7 +226,7 @@ func (oauth *OAuth) Token() Token {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
+func (oauth *OAuth) tokensWithAuthCode(ctx context.Context, authCode string) error {
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
u := oauth.TokenURL
@@ -245,7 +250,7 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
now := time.Now()
// We are sure that we have a http client because we have initialized it when starting the exchange
- _, body, err := oauth.httpClient.PostWithOpts(u, opts)
+ _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts)
if err != nil {
return err
}
@@ -274,7 +279,7 @@ func (oauth *OAuth) UpdateTokens(t Token) {
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) {
+func (oauth *OAuth) refreshResponse(ctx context.Context, r string) (*TokenResponse, time.Time, error) {
u := oauth.TokenURL
if oauth.token == nil {
return nil, time.Time{}, errors.New("No oauth token structure in refresh")
@@ -298,7 +303,7 @@ func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error)
oauth.httpClient = httpw.NewClient()
}
- _, body, err := oauth.httpClient.PostWithOpts(u, opts)
+ _, body, err := oauth.httpClient.PostWithOpts(ctx, u, opts)
if err != nil {
return nil, time.Time{}, err
}
@@ -394,7 +399,7 @@ func (s *exchangeSession) Authcode(url *url.URL) (string, error) {
// tokenHandler gets the tokens using the authorization code that is obtained through the url
// This function is called by the http handler and returns an error if the tokens cannot be obtained
-func (oauth *OAuth) tokenHandler(url *url.URL) error {
+func (oauth *OAuth) tokenHandler(ctx context.Context, url *url.URL) error {
// Get the authorization code
c, err := oauth.session.Authcode(url)
if err != nil {
@@ -402,14 +407,15 @@ func (oauth *OAuth) tokenHandler(url *url.URL) error {
}
// Now that we have obtained the authorization code, we can move to the next step:
// Obtaining the access and refresh tokens
- return oauth.tokensWithAuthCode(c)
+ return oauth.tokensWithAuthCode(ctx, c)
}
// Handler is the function used to get the OAuth tokens using an authorization code callback
// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
// It sends an error to the session channel (can be nil)
func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) {
- err := oauth.tokenHandler(req.URL)
+ // TODO: should this be something else than context background?
+ err := oauth.tokenHandler(context.Background(), req.URL)
if err != nil {
_ = writeResponseHTML(
w,
@@ -507,17 +513,12 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
// Exchange starts the OAuth exchange by getting the tokens with the redirect callback
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) Exchange() error {
+func (oauth *OAuth) Exchange(ctx context.Context) error {
// If there is no HTTP client defined, create a new one
if oauth.httpClient == nil {
oauth.httpClient = httpw.NewClient()
}
- return oauth.tokensWithCallback()
-}
-
-// Cancel cancels the existing OAuth server by sending a cancel error to the channel
-func (oauth *OAuth) Cancel() {
- oauth.session.ErrChan <- errors.Wrap(&CancelledCallbackError{}, 0)
+ return oauth.tokensWithCallback(ctx)
}
type CancelledCallbackError struct{}