summaryrefslogtreecommitdiff
path: root/internal/oauth/oauth.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/oauth/oauth.go')
-rw-r--r--internal/oauth/oauth.go159
1 files changed, 74 insertions, 85 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 6d63235..84ecdc4 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -80,9 +80,6 @@ type OAuth struct {
// ISS indicates the issuer indentifier of the authorization server as defined in RFC 9207
ISS string `json:"iss"`
- // Token is where the access and refresh tokens are stored along with the timestamps
- Token OAuthToken `json:"token"`
-
// BaseAuthorizationURL is the URL where authorization should take place
BaseAuthorizationURL string `json:"base_authorization_url"`
@@ -91,6 +88,9 @@ type OAuth struct {
// session is the internal in progress OAuth session
session OAuthExchangeSession `json:"-"`
+
+ // Token is where the access and refresh tokens are stored along with the timestamps
+ token OAuthToken `json:"-"`
}
// OAuthExchangeSession is a structure that gets passed to the callback for easy access to the current state.
@@ -120,22 +120,40 @@ type OAuthExchangeSession struct {
Listener net.Listener
}
-// OAuthToken is a structure that defines the json format for /.well-known/vpn-user-portal".
-type OAuthToken struct {
- // Access is the access token returned by the server
- Access string `json:"access_token"`
+// AccessToken gets the OAuth access token used for contacting the server API
+// It returns the access token as a string, possibly obtained fresh using the Refresh Token
+// If the token cannot be obtained, an error is returned and the token is an empty string.
+func (oauth *OAuth) AccessToken() (string, error) {
+ errorMessage := "failed getting access token"
+ tokens := oauth.token
- // Refresh token is the refresh token returned by the server
- Refresh string `json:"refresh_token"`
+ // We have tokens...
+ // The tokens are not expired yet
+ // So they should be valid, re-login not needed
+ if !tokens.Expired() {
+ return tokens.access, nil
+ }
- // Type indicates which type of tokens we have
- Type string `json:"token_type"`
+ // Check if refresh is even possible by doing a simple check if the refresh token is empty
+ // This is not needed but reduces API calls to the server
+ if tokens.refresh == "" {
+ return "", types.NewWrappedError(errorMessage, &OAuthTokensInvalidError{Cause: "no refresh token is present"})
+ }
- // Expires is the expires time returned by the server
- Expires int64 `json:"expires_in"`
+ // Otherwise refresh and then later return the access token if we are successful
+ refreshErr := oauth.tokensWithRefresh()
+ if refreshErr != nil {
+ // We have failed to ensure the tokens due to refresh not working
+ return "", types.NewWrappedError(
+ errorMessage,
+ &OAuthTokensInvalidError{
+ Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
+ },
+ )
+ }
- // ExpiredTimestamp is the Expires field but converted to a Go timestamp
- ExpiredTimestamp time.Time `json:"expires_in_timestamp"`
+ // We have obtained new tokens with refresh
+ return tokens.access, nil
}
// setupListener sets up an OAuth listener
@@ -173,6 +191,40 @@ func (oauth *OAuth) tokensWithCallback() error {
return oauth.session.CallbackError
}
+// fillToken fills the OAuth token structure by the response
+// It calculates the expired timestamp by having a 'startTime' passed to it
+// The URL that is input here is used for additional context.
+func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error {
+ responseStructure := OAuthTokenResponse{}
+
+ jsonErr := json.Unmarshal(response, &responseStructure)
+ if jsonErr != nil {
+ return types.NewWrappedError(
+ "failed filling OAuth tokens",
+ &httpw.HTTPParseJSONError{URL: url, Body: string(response), Err: jsonErr},
+ )
+ }
+
+ internalStructure := OAuthToken{}
+ internalStructure.expiredTimestamp = startTime.Add(
+ time.Second * time.Duration(responseStructure.Expires),
+ )
+ internalStructure.access = responseStructure.Access
+ internalStructure.refresh = responseStructure.Refresh
+ oauth.token = internalStructure
+ return nil
+}
+
+// SetTokenExpired marks the tokens as expired by setting the expired timestamp to the current time.
+func (oauth *OAuth) SetTokenExpired() {
+ oauth.token.expiredTimestamp = time.Now()
+}
+
+// SetTokenRenew sets the tokens for renewal by completely clearing the structure.
+func (oauth *OAuth) SetTokenRenew() {
+ oauth.token = OAuthToken{}
+}
+
// tokensWithAuthCode gets the access and refresh tokens using the authorization code
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
@@ -205,31 +257,13 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
return types.NewWrappedError(errorMessage, bodyErr)
}
- tokenStructure := OAuthToken{}
-
- jsonErr := json.Unmarshal(body, &tokenStructure)
-
- if jsonErr != nil {
- return types.NewWrappedError(
- errorMessage,
- &httpw.HTTPParseJSONError{URL: reqURL, Body: string(body), Err: jsonErr},
- )
+ fillErr := oauth.fillToken(body, currentTime, reqURL)
+ if fillErr != nil {
+ return types.NewWrappedError(errorMessage, fillErr)
}
-
- tokenStructure.ExpiredTimestamp = currentTime.Add(
- time.Second * time.Duration(tokenStructure.Expires),
- )
- oauth.Token = tokenStructure
return nil
}
-// isTokensExpired returns if the OAuth tokens are expired using the expired timestamp.
-func (oauth *OAuth) isTokensExpired() bool {
- expiredTime := oauth.Token.ExpiredTimestamp
- currentTime := time.Now()
- return !currentTime.Before(expiredTime)
-}
-
// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
@@ -238,7 +272,7 @@ func (oauth *OAuth) tokensWithRefresh() error {
errorMessage := "failed getting tokens with the refresh token"
reqURL := oauth.TokenURL
data := url.Values{
- "refresh_token": {oauth.Token.Refresh},
+ "refresh_token": {oauth.token.refresh},
"grant_type": {"refresh_token"},
}
headers := http.Header{
@@ -251,20 +285,10 @@ func (oauth *OAuth) tokensWithRefresh() error {
return types.NewWrappedError(errorMessage, bodyErr)
}
- tokenStructure := OAuthToken{}
- jsonErr := json.Unmarshal(body, &tokenStructure)
-
- if jsonErr != nil {
- return types.NewWrappedError(
- errorMessage,
- &httpw.HTTPParseJSONError{URL: reqURL, Body: string(body), Err: jsonErr},
- )
+ fillErr := oauth.fillToken(body, currentTime, reqURL)
+ if fillErr != nil {
+ return types.NewWrappedError(errorMessage, fillErr)
}
-
- tokenStructure.ExpiredTimestamp = currentTime.Add(
- time.Second * time.Duration(tokenStructure.Expires),
- )
- oauth.Token = tokenStructure
return nil
}
@@ -506,41 +530,6 @@ func (oauth *OAuth) Cancel() {
}
}
-// EnsureTokens makes sure the OAuth tokens are still valid
-// if this cannot be guaranteed, it returns an error.
-func (oauth *OAuth) EnsureTokens() error {
- errorMessage := "failed ensuring OAuth tokens"
- // Access Token or Refresh Tokens empty, we can not ensure the tokens
- if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
- return types.NewWrappedError(
- errorMessage,
- &OAuthTokensInvalidError{Cause: "tokens are empty"},
- )
- }
-
- // We have tokens...
- // The tokens are not expired yet
- // So they should be valid, re-login not needed
- if !oauth.isTokensExpired() {
- return nil
- }
-
- // Otherwise try to refresh them and return if successful
- refreshErr := oauth.tokensWithRefresh()
- // We have obtained new tokens with refresh
- if refreshErr != nil {
- // We have failed to ensure the tokens due to refresh not working
- return types.NewWrappedError(
- errorMessage,
- &OAuthTokensInvalidError{
- Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
- },
- )
- }
-
- return nil
-}
-
type OAuthCancelledCallbackError struct{}
func (e *OAuthCancelledCallbackError) Error() string {