summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-15 13:43:55 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-21 18:27:33 +0100
commit0357b97b15c62147557b49cd50ee690aba93c06d (patch)
tree43629fc2cdb41fed3d4e8a9cf26bf12c9081bd9b
parentb7734f0dd4d4b2f8be0db90267638ed22807ba81 (diff)
OAuth: Use a mutex to protect the token structure
-rw-r--r--internal/oauth/oauth.go99
-rw-r--r--internal/oauth/token.go94
2 files changed, 140 insertions, 53 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 52a64b6..e3bbd6e 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -88,7 +88,8 @@ type OAuth struct {
session ExchangeSession
// Token is where the access and refresh tokens are stored along with the timestamps
- token Token
+ // It is protected by a lock
+ token *tokenLock
}
// ExchangeSession is a structure that gets passed to the callback for easy access to the current state.
@@ -119,31 +120,11 @@ type ExchangeSession struct {
// It returns the access token as a string, possibly obtained fresh using the Refresh Token
// If the token cannot be obtained, an error is returned and the token is an empty string.
func (oauth *OAuth) AccessToken() (string, error) {
- t := oauth.token
-
- // We have tokens...
- // The tokens are not expired yet
- // So they should be valid, re-login not needed
- if !t.Expired() {
- return t.access, nil
- }
-
- // Check if refresh is even possible by doing a simple check if the refresh token is empty
- // This is not needed but reduces API calls to the server
- if t.refresh == "" {
- return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
- }
-
- // Otherwise refresh and then later return the access token if we are successful
- err := oauth.tokensWithRefresh()
- if err != nil {
- // We have failed to ensure the tokens due to refresh not working
- return "", errors.Wrap(
- &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0)
+ tl := oauth.token
+ if tl == nil {
+ return "", errors.New("No token structure available")
}
-
- // We have obtained new tokens with refresh
- return t.access, nil
+ return tl.Access()
}
// setupListener sets up an OAuth listener
@@ -188,33 +169,35 @@ func (oauth *OAuth) tokensWithCallback() error {
return <-oauth.session.ErrChan
}
-// fillToken fills the OAuth token structure by the response
-// It calculates the expired timestamp by having a 'startTime' passed to it
-// The URL that is input here is used for additional context.
-func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error {
+// tokenResponse fills the OAuth token response structure by the response
+// The URL that is input here is used for additional context
+// It returns this structure and an error if there is one
+func (oauth *OAuth) tokenResponse(response []byte, url string) (*TokenResponse, error) {
+ if oauth.token == nil {
+ return nil, errors.New("No oauth structure when filling token")
+ }
res := TokenResponse{}
err := json.Unmarshal(response, &res)
if err != nil {
- return errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0)
+ return nil, errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0)
}
- oauth.token = Token{
- access: res.Access,
- refresh: res.Refresh,
- expiredTimestamp: startTime.Add(time.Second * time.Duration(res.Expires)),
- }
- return nil
+ return &res, nil
}
// SetTokenExpired marks the tokens as expired by setting the expired timestamp to the current time.
func (oauth *OAuth) SetTokenExpired() {
- oauth.token.expiredTimestamp = time.Now()
+ if oauth.token != nil {
+ oauth.token.SetExpired()
+ }
}
// SetTokenRenew sets the tokens for renewal by completely clearing the structure.
func (oauth *OAuth) SetTokenRenew() {
- oauth.token = Token{}
+ if oauth.token != nil {
+ oauth.token.Clear()
+ }
}
// tokensWithAuthCode gets the access and refresh tokens using the authorization code
@@ -248,17 +231,30 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
return err
}
- return oauth.fillToken(body, now, u)
+ tr, err := oauth.tokenResponse(body, u)
+ if err != nil {
+ return err
+ }
+ if tr == nil {
+ return errors.New("No token response after authorization code")
+ }
+
+ oauth.token.Update(*tr, now)
+ return nil
}
-// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token
+// refreshResponse gets the refresh token response with a refresh token
+// This response contains the access and refresh tokens, together with a timestamp
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
// If it was unsuccessful it returns an error.
-func (oauth *OAuth) tokensWithRefresh() error {
+func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) {
u := oauth.TokenURL
+ if oauth.token == nil {
+ return nil, time.Time{}, errors.New("No oauth token structure in refresh")
+ }
data := url.Values{
- "refresh_token": {oauth.token.refresh},
+ "refresh_token": {r},
"grant_type": {"refresh_token"},
}
h := http.Header{
@@ -268,10 +264,11 @@ func (oauth *OAuth) tokensWithRefresh() error {
now := time.Now()
_, body, err := httpw.PostWithOpts(u, opts)
if err != nil {
- return err
+ return nil, time.Time{}, err
}
- return oauth.fillToken(body, now, u)
+ tr, err := oauth.tokenResponse(body, u)
+ return tr, now, err
}
// responseTemplate is the HTML template for the OAuth authorized response
@@ -329,6 +326,8 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro
return t.Execute(w, oauthResponseHTML{Title: title, Message: message})
}
+// Authcode gets the authorization code from the url
+// It returns the code and an error if there is one
func (s *ExchangeSession) Authcode(url *url.URL) (string, error) {
// ISS: https://www.rfc-editor.org/rfc/rfc9207.html
// TODO: Make this a required parameter in the future
@@ -358,7 +357,9 @@ func (s *ExchangeSession) Authcode(url *url.URL) (string, error) {
return code, nil
}
-func (oauth *OAuth) TokenHandler(url *url.URL) error {
+// tokenHandler gets the tokens using the authorization code that is obtained through the url
+// This function is called by the http handler and returns an error if the tokens cannot be obtained
+func (oauth *OAuth) tokenHandler(url *url.URL) error {
// Get the authorization code
c, err := oauth.session.Authcode(url)
if err != nil {
@@ -369,10 +370,11 @@ func (oauth *OAuth) TokenHandler(url *url.URL) error {
return oauth.tokensWithAuthCode(c)
}
-// Callback is the public function used to get the OAuth tokens using an authorization code callback
+// Handler is the function used to get the OAuth tokens using an authorization code callback
// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1
+// It sends an error to the session channel (can be nil)
func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) {
- err := oauth.TokenHandler(req.URL)
+ err := oauth.tokenHandler(req.URL)
if err != nil {
_ = writeResponseHTML(
w,
@@ -418,6 +420,9 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
return "", errors.WrapPrefix(err, "genState error", 0)
}
+ // Fill the oauth tokens
+ oauth.token = &tokenLock{t: &token{Refresher: oauth.refreshResponse}}
+
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauth.session = ExchangeSession{
ClientID: name,
diff --git a/internal/oauth/token.go b/internal/oauth/token.go
index 31c2c74..855677c 100644
--- a/internal/oauth/token.go
+++ b/internal/oauth/token.go
@@ -1,6 +1,12 @@
package oauth
-import "time"
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/go-errors/errors"
+)
// TokenResponse defines the OAuth response from the server that includes the tokens.
type TokenResponse struct {
@@ -17,8 +23,8 @@ type TokenResponse struct {
Expires int64 `json:"expires_in"`
}
-// Token is a structure that contains our access and refresh tokens and a timestamp when they expire.
-type Token struct {
+// token is a structure that contains our access and refresh tokens and a timestamp when they expire.
+type token struct {
// Access is the access token returned by the server
access string
@@ -27,10 +33,86 @@ type Token struct {
// ExpiredTimestamp is the Expires field but converted to a Go timestamp
expiredTimestamp time.Time
+
+ // Refresher is the function that refreshes the token
+ Refresher func(string) (*TokenResponse, time.Time, error)
+}
+
+// tokenLock is a wrapper around token that protects it with a lock
+type tokenLock struct {
+ // Protects t
+ mu sync.Mutex
+
+ // The token fields protected by the lock
+ t *token
+}
+
+// Access gets the OAuth access token used for contacting the server API
+// It returns the access token as a string, possibly obtained fresh using the refresher
+// If the token cannot be obtained, an error is returned and the token is an empty string.
+func (l *tokenLock) Access() (string, error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+
+ // The tokens are not expired yet
+ // So they should be valid, re-login not neede
+ if !l.expired() {
+ return l.t.access, nil
+ }
+
+ // Check if refresh is even possible by doing a simple check if the refresh token is empty
+ // This is not needed but reduces API calls to the server
+ if l.t.refresh == "" {
+ return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
+ }
+
+ // Otherwise refresh and then later return the access token if we are successful
+ tr, s, err := l.t.Refresher(l.t.refresh)
+ if err != nil {
+ // We have failed to ensure the tokens due to refresh not working
+ return "", errors.Wrap(
+ &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0)
+ }
+ if tr == nil {
+ return "", errors.New("No token response after refreshing")
+ }
+ l.updateInternal(*tr, s)
+ return l.t.access, nil
+}
+
+// Clear completely clears the token structure
+// This is useful for forcing re-authorization
+func (l *tokenLock) Clear() {
+ l.mu.Lock()
+ l.t = &token{}
+ l.mu.Unlock()
+}
+
+// updateInternal updates the structure using the response without locking
+func (l *tokenLock) updateInternal(r TokenResponse, s time.Time) {
+ l.t.access = r.Access
+ l.t.refresh = r.Refresh
+ l.t.expiredTimestamp = s.Add(time.Second * time.Duration(r.Expires))
+}
+
+// Update updates the structure usign the response and locks
+func (l *tokenLock) Update(r TokenResponse, s time.Time) {
+ l.mu.Lock()
+ l.updateInternal(r, s)
+ l.mu.Unlock()
+}
+
+// SetExpired overrides the timestamp to the current time
+// This marks the tokens as expired
+func (l *tokenLock) SetExpired() {
+ l.mu.Lock()
+ l.t.expiredTimestamp = time.Now()
+ l.mu.Unlock()
}
-// Expired checks if the access token is expired.
-func (tokens *Token) Expired() bool {
+// expired checks if the access token is expired.
+// This is only called internally and thus does not lock
+func (l *tokenLock) expired() bool {
now := time.Now()
- return !now.Before(tokens.expiredTimestamp)
+ return !now.Before(l.t.expiredTimestamp)
}