diff options
| -rw-r--r-- | internal/oauth/oauth.go | 99 | ||||
| -rw-r--r-- | internal/oauth/token.go | 94 |
2 files changed, 140 insertions, 53 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 52a64b6..e3bbd6e 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -88,7 +88,8 @@ type OAuth struct { session ExchangeSession // Token is where the access and refresh tokens are stored along with the timestamps - token Token + // It is protected by a lock + token *tokenLock } // ExchangeSession is a structure that gets passed to the callback for easy access to the current state. @@ -119,31 +120,11 @@ type ExchangeSession struct { // It returns the access token as a string, possibly obtained fresh using the Refresh Token // If the token cannot be obtained, an error is returned and the token is an empty string. func (oauth *OAuth) AccessToken() (string, error) { - t := oauth.token - - // We have tokens... - // The tokens are not expired yet - // So they should be valid, re-login not needed - if !t.Expired() { - return t.access, nil - } - - // Check if refresh is even possible by doing a simple check if the refresh token is empty - // This is not needed but reduces API calls to the server - if t.refresh == "" { - return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) - } - - // Otherwise refresh and then later return the access token if we are successful - err := oauth.tokensWithRefresh() - if err != nil { - // We have failed to ensure the tokens due to refresh not working - return "", errors.Wrap( - &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0) + tl := oauth.token + if tl == nil { + return "", errors.New("No token structure available") } - - // We have obtained new tokens with refresh - return t.access, nil + return tl.Access() } // setupListener sets up an OAuth listener @@ -188,33 +169,35 @@ func (oauth *OAuth) tokensWithCallback() error { return <-oauth.session.ErrChan } -// fillToken fills the OAuth token structure by the response -// It calculates the expired timestamp by having a 'startTime' passed to it -// The URL that is input here is used for additional context. -func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error { +// tokenResponse fills the OAuth token response structure by the response +// The URL that is input here is used for additional context +// It returns this structure and an error if there is one +func (oauth *OAuth) tokenResponse(response []byte, url string) (*TokenResponse, error) { + if oauth.token == nil { + return nil, errors.New("No oauth structure when filling token") + } res := TokenResponse{} err := json.Unmarshal(response, &res) if err != nil { - return errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0) + return nil, errors.WrapPrefix(err, "failed filling OAuth tokens from "+url, 0) } - oauth.token = Token{ - access: res.Access, - refresh: res.Refresh, - expiredTimestamp: startTime.Add(time.Second * time.Duration(res.Expires)), - } - return nil + return &res, nil } // SetTokenExpired marks the tokens as expired by setting the expired timestamp to the current time. func (oauth *OAuth) SetTokenExpired() { - oauth.token.expiredTimestamp = time.Now() + if oauth.token != nil { + oauth.token.SetExpired() + } } // SetTokenRenew sets the tokens for renewal by completely clearing the structure. func (oauth *OAuth) SetTokenRenew() { - oauth.token = Token{} + if oauth.token != nil { + oauth.token.Clear() + } } // tokensWithAuthCode gets the access and refresh tokens using the authorization code @@ -248,17 +231,30 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { return err } - return oauth.fillToken(body, now, u) + tr, err := oauth.tokenResponse(body, u) + if err != nil { + return err + } + if tr == nil { + return errors.New("No token response after authorization code") + } + + oauth.token.Update(*tr, now) + return nil } -// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token +// refreshResponse gets the refresh token response with a refresh token +// This response contains the access and refresh tokens, together with a timestamp // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 // If it was unsuccessful it returns an error. -func (oauth *OAuth) tokensWithRefresh() error { +func (oauth *OAuth) refreshResponse(r string) (*TokenResponse, time.Time, error) { u := oauth.TokenURL + if oauth.token == nil { + return nil, time.Time{}, errors.New("No oauth token structure in refresh") + } data := url.Values{ - "refresh_token": {oauth.token.refresh}, + "refresh_token": {r}, "grant_type": {"refresh_token"}, } h := http.Header{ @@ -268,10 +264,11 @@ func (oauth *OAuth) tokensWithRefresh() error { now := time.Now() _, body, err := httpw.PostWithOpts(u, opts) if err != nil { - return err + return nil, time.Time{}, err } - return oauth.fillToken(body, now, u) + tr, err := oauth.tokenResponse(body, u) + return tr, now, err } // responseTemplate is the HTML template for the OAuth authorized response @@ -329,6 +326,8 @@ func writeResponseHTML(w http.ResponseWriter, title string, message string) erro return t.Execute(w, oauthResponseHTML{Title: title, Message: message}) } +// Authcode gets the authorization code from the url +// It returns the code and an error if there is one func (s *ExchangeSession) Authcode(url *url.URL) (string, error) { // ISS: https://www.rfc-editor.org/rfc/rfc9207.html // TODO: Make this a required parameter in the future @@ -358,7 +357,9 @@ func (s *ExchangeSession) Authcode(url *url.URL) (string, error) { return code, nil } -func (oauth *OAuth) TokenHandler(url *url.URL) error { +// tokenHandler gets the tokens using the authorization code that is obtained through the url +// This function is called by the http handler and returns an error if the tokens cannot be obtained +func (oauth *OAuth) tokenHandler(url *url.URL) error { // Get the authorization code c, err := oauth.session.Authcode(url) if err != nil { @@ -369,10 +370,11 @@ func (oauth *OAuth) TokenHandler(url *url.URL) error { return oauth.tokensWithAuthCode(c) } -// Callback is the public function used to get the OAuth tokens using an authorization code callback +// Handler is the function used to get the OAuth tokens using an authorization code callback // The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 +// It sends an error to the session channel (can be nil) func (oauth *OAuth) Handler(w http.ResponseWriter, req *http.Request) { - err := oauth.TokenHandler(req.URL) + err := oauth.tokenHandler(req.URL) if err != nil { _ = writeResponseHTML( w, @@ -418,6 +420,9 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s return "", errors.WrapPrefix(err, "genState error", 0) } + // Fill the oauth tokens + oauth.token = &tokenLock{t: &token{Refresher: oauth.refreshResponse}} + // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauth.session = ExchangeSession{ ClientID: name, diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 31c2c74..855677c 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -1,6 +1,12 @@ package oauth -import "time" +import ( + "fmt" + "sync" + "time" + + "github.com/go-errors/errors" +) // TokenResponse defines the OAuth response from the server that includes the tokens. type TokenResponse struct { @@ -17,8 +23,8 @@ type TokenResponse struct { Expires int64 `json:"expires_in"` } -// Token is a structure that contains our access and refresh tokens and a timestamp when they expire. -type Token struct { +// token is a structure that contains our access and refresh tokens and a timestamp when they expire. +type token struct { // Access is the access token returned by the server access string @@ -27,10 +33,86 @@ type Token struct { // ExpiredTimestamp is the Expires field but converted to a Go timestamp expiredTimestamp time.Time + + // Refresher is the function that refreshes the token + Refresher func(string) (*TokenResponse, time.Time, error) +} + +// tokenLock is a wrapper around token that protects it with a lock +type tokenLock struct { + // Protects t + mu sync.Mutex + + // The token fields protected by the lock + t *token +} + +// Access gets the OAuth access token used for contacting the server API +// It returns the access token as a string, possibly obtained fresh using the refresher +// If the token cannot be obtained, an error is returned and the token is an empty string. +func (l *tokenLock) Access() (string, error) { + l.mu.Lock() + defer l.mu.Unlock() + + // The tokens are not expired yet + // So they should be valid, re-login not neede + if !l.expired() { + return l.t.access, nil + } + + // Check if refresh is even possible by doing a simple check if the refresh token is empty + // This is not needed but reduces API calls to the server + if l.t.refresh == "" { + return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) + } + + // Otherwise refresh and then later return the access token if we are successful + tr, s, err := l.t.Refresher(l.t.refresh) + if err != nil { + // We have failed to ensure the tokens due to refresh not working + return "", errors.Wrap( + &TokensInvalidError{Cause: fmt.Sprintf("tokens failed refresh with error: %v", err)}, 0) + } + if tr == nil { + return "", errors.New("No token response after refreshing") + } + l.updateInternal(*tr, s) + return l.t.access, nil +} + +// Clear completely clears the token structure +// This is useful for forcing re-authorization +func (l *tokenLock) Clear() { + l.mu.Lock() + l.t = &token{} + l.mu.Unlock() +} + +// updateInternal updates the structure using the response without locking +func (l *tokenLock) updateInternal(r TokenResponse, s time.Time) { + l.t.access = r.Access + l.t.refresh = r.Refresh + l.t.expiredTimestamp = s.Add(time.Second * time.Duration(r.Expires)) +} + +// Update updates the structure usign the response and locks +func (l *tokenLock) Update(r TokenResponse, s time.Time) { + l.mu.Lock() + l.updateInternal(r, s) + l.mu.Unlock() +} + +// SetExpired overrides the timestamp to the current time +// This marks the tokens as expired +func (l *tokenLock) SetExpired() { + l.mu.Lock() + l.t.expiredTimestamp = time.Now() + l.mu.Unlock() } -// Expired checks if the access token is expired. -func (tokens *Token) Expired() bool { +// expired checks if the access token is expired. +// This is only called internally and thus does not lock +func (l *tokenLock) expired() bool { now := time.Now() - return !now.Before(tokens.expiredTimestamp) + return !now.Before(l.t.expiredTimestamp) } |
