summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-11-28 13:28:27 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-11-28 13:50:02 +0100
commit279c0de75629de5868c3ac1b3272a2850e6b62f7 (patch)
treeb01b764baca799fe952f01a25f1cf5e05ced8333 /internal
parent7bab6c76599fdfd34ea9bb064d871ed2be01d4c8 (diff)
OAuth: Refactor Token getting and do not save them in the config
This commit refactors getting the tokens into receiver methods. This means that functions do not have to call the cryptic "EnsureTokens" method. The receiver getter then already verifier whether or not the tokens could be obtained (and refreshes too). The downside is that some things are now private, so testing for invalid tokens needs to be done somewhere else. This needs another patch such that clients can save the tokens themselves using a keyring.
Diffstat (limited to 'internal')
-rw-r--r--internal/oauth/oauth.go159
-rw-r--r--internal/oauth/token.go37
-rw-r--r--internal/server/api.go8
-rw-r--r--internal/server/common.go23
4 files changed, 125 insertions, 102 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 6d63235..84ecdc4 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -80,9 +80,6 @@ type OAuth struct {
// ISS indicates the issuer indentifier of the authorization server as defined in RFC 9207
ISS string `json:"iss"`
- // Token is where the access and refresh tokens are stored along with the timestamps
- Token OAuthToken `json:"token"`
-
// BaseAuthorizationURL is the URL where authorization should take place
BaseAuthorizationURL string `json:"base_authorization_url"`
@@ -91,6 +88,9 @@ type OAuth struct {
// session is the internal in progress OAuth session
session OAuthExchangeSession `json:"-"`
+
+ // Token is where the access and refresh tokens are stored along with the timestamps
+ token OAuthToken `json:"-"`
}
// OAuthExchangeSession is a structure that gets passed to the callback for easy access to the current state.
@@ -120,22 +120,40 @@ type OAuthExchangeSession struct {
Listener net.Listener
}
-// OAuthToken is a structure that defines the json format for /.well-known/vpn-user-portal".
-type OAuthToken struct {
- // Access is the access token returned by the server
- Access string `json:"access_token"`
+// AccessToken gets the OAuth access token used for contacting the server API
+// It returns the access token as a string, possibly obtained fresh using the Refresh Token
+// If the token cannot be obtained, an error is returned and the token is an empty string.
+func (oauth *OAuth) AccessToken() (string, error) {
+ errorMessage := "failed getting access token"
+ tokens := oauth.token
- // Refresh token is the refresh token returned by the server
- Refresh string `json:"refresh_token"`
+ // We have tokens...
+ // The tokens are not expired yet
+ // So they should be valid, re-login not needed
+ if !tokens.Expired() {
+ return tokens.access, nil
+ }
- // Type indicates which type of tokens we have
- Type string `json:"token_type"`
+ // Check if refresh is even possible by doing a simple check if the refresh token is empty
+ // This is not needed but reduces API calls to the server
+ if tokens.refresh == "" {
+ return "", types.NewWrappedError(errorMessage, &OAuthTokensInvalidError{Cause: "no refresh token is present"})
+ }
- // Expires is the expires time returned by the server
- Expires int64 `json:"expires_in"`
+ // Otherwise refresh and then later return the access token if we are successful
+ refreshErr := oauth.tokensWithRefresh()
+ if refreshErr != nil {
+ // We have failed to ensure the tokens due to refresh not working
+ return "", types.NewWrappedError(
+ errorMessage,
+ &OAuthTokensInvalidError{
+ Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
+ },
+ )
+ }
- // ExpiredTimestamp is the Expires field but converted to a Go timestamp
- ExpiredTimestamp time.Time `json:"expires_in_timestamp"`
+ // We have obtained new tokens with refresh
+ return tokens.access, nil
}
// setupListener sets up an OAuth listener
@@ -173,6 +191,40 @@ func (oauth *OAuth) tokensWithCallback() error {
return oauth.session.CallbackError
}
+// fillToken fills the OAuth token structure by the response
+// It calculates the expired timestamp by having a 'startTime' passed to it
+// The URL that is input here is used for additional context.
+func (oauth *OAuth) fillToken(response []byte, startTime time.Time, url string) error {
+ responseStructure := OAuthTokenResponse{}
+
+ jsonErr := json.Unmarshal(response, &responseStructure)
+ if jsonErr != nil {
+ return types.NewWrappedError(
+ "failed filling OAuth tokens",
+ &httpw.HTTPParseJSONError{URL: url, Body: string(response), Err: jsonErr},
+ )
+ }
+
+ internalStructure := OAuthToken{}
+ internalStructure.expiredTimestamp = startTime.Add(
+ time.Second * time.Duration(responseStructure.Expires),
+ )
+ internalStructure.access = responseStructure.Access
+ internalStructure.refresh = responseStructure.Refresh
+ oauth.token = internalStructure
+ return nil
+}
+
+// SetTokenExpired marks the tokens as expired by setting the expired timestamp to the current time.
+func (oauth *OAuth) SetTokenExpired() {
+ oauth.token.expiredTimestamp = time.Now()
+}
+
+// SetTokenRenew sets the tokens for renewal by completely clearing the structure.
+func (oauth *OAuth) SetTokenRenew() {
+ oauth.token = OAuthToken{}
+}
+
// tokensWithAuthCode gets the access and refresh tokens using the authorization code
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
@@ -205,31 +257,13 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
return types.NewWrappedError(errorMessage, bodyErr)
}
- tokenStructure := OAuthToken{}
-
- jsonErr := json.Unmarshal(body, &tokenStructure)
-
- if jsonErr != nil {
- return types.NewWrappedError(
- errorMessage,
- &httpw.HTTPParseJSONError{URL: reqURL, Body: string(body), Err: jsonErr},
- )
+ fillErr := oauth.fillToken(body, currentTime, reqURL)
+ if fillErr != nil {
+ return types.NewWrappedError(errorMessage, fillErr)
}
-
- tokenStructure.ExpiredTimestamp = currentTime.Add(
- time.Second * time.Duration(tokenStructure.Expires),
- )
- oauth.Token = tokenStructure
return nil
}
-// isTokensExpired returns if the OAuth tokens are expired using the expired timestamp.
-func (oauth *OAuth) isTokensExpired() bool {
- expiredTime := oauth.Token.ExpiredTimestamp
- currentTime := time.Now()
- return !currentTime.Before(expiredTime)
-}
-
// tokensWithRefresh gets the access and refresh tokens with a previously received refresh token
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
@@ -238,7 +272,7 @@ func (oauth *OAuth) tokensWithRefresh() error {
errorMessage := "failed getting tokens with the refresh token"
reqURL := oauth.TokenURL
data := url.Values{
- "refresh_token": {oauth.Token.Refresh},
+ "refresh_token": {oauth.token.refresh},
"grant_type": {"refresh_token"},
}
headers := http.Header{
@@ -251,20 +285,10 @@ func (oauth *OAuth) tokensWithRefresh() error {
return types.NewWrappedError(errorMessage, bodyErr)
}
- tokenStructure := OAuthToken{}
- jsonErr := json.Unmarshal(body, &tokenStructure)
-
- if jsonErr != nil {
- return types.NewWrappedError(
- errorMessage,
- &httpw.HTTPParseJSONError{URL: reqURL, Body: string(body), Err: jsonErr},
- )
+ fillErr := oauth.fillToken(body, currentTime, reqURL)
+ if fillErr != nil {
+ return types.NewWrappedError(errorMessage, fillErr)
}
-
- tokenStructure.ExpiredTimestamp = currentTime.Add(
- time.Second * time.Duration(tokenStructure.Expires),
- )
- oauth.Token = tokenStructure
return nil
}
@@ -506,41 +530,6 @@ func (oauth *OAuth) Cancel() {
}
}
-// EnsureTokens makes sure the OAuth tokens are still valid
-// if this cannot be guaranteed, it returns an error.
-func (oauth *OAuth) EnsureTokens() error {
- errorMessage := "failed ensuring OAuth tokens"
- // Access Token or Refresh Tokens empty, we can not ensure the tokens
- if oauth.Token.Access == "" && oauth.Token.Refresh == "" {
- return types.NewWrappedError(
- errorMessage,
- &OAuthTokensInvalidError{Cause: "tokens are empty"},
- )
- }
-
- // We have tokens...
- // The tokens are not expired yet
- // So they should be valid, re-login not needed
- if !oauth.isTokensExpired() {
- return nil
- }
-
- // Otherwise try to refresh them and return if successful
- refreshErr := oauth.tokensWithRefresh()
- // We have obtained new tokens with refresh
- if refreshErr != nil {
- // We have failed to ensure the tokens due to refresh not working
- return types.NewWrappedError(
- errorMessage,
- &OAuthTokensInvalidError{
- Cause: fmt.Sprintf("tokens failed refresh with error: %v", refreshErr),
- },
- )
- }
-
- return nil
-}
-
type OAuthCancelledCallbackError struct{}
func (e *OAuthCancelledCallbackError) Error() string {
diff --git a/internal/oauth/token.go b/internal/oauth/token.go
new file mode 100644
index 0000000..8ceb9a8
--- /dev/null
+++ b/internal/oauth/token.go
@@ -0,0 +1,37 @@
+package oauth
+
+import "time"
+
+// OAuthTokenResponse defines the OAuth response from the server that includes the tokens.
+type OAuthTokenResponse struct {
+ // Access is the access token returned by the server
+ Access string `json:"access_token"`
+
+ // Refresh token is the refresh token returned by the server
+ Refresh string `json:"refresh_token"`
+
+ // Type indicates which type of tokens we have
+ Type string `json:"token_type"`
+
+ // Expires is the expires time returned by the server
+ Expires int64 `json:"expires_in"`
+
+}
+
+// OAuthToken is a structure that contains our access and refresh tokens and a timestamp when they expire.
+type OAuthToken struct {
+ // Access is the access token returned by the server
+ access string
+
+ // Refresh token is the refresh token returned by the server
+ refresh string
+
+ // ExpiredTimestamp is the Expires field but converted to a Go timestamp
+ expiredTimestamp time.Time
+}
+
+// Expired checks if the access token is expired.
+func (tokens *OAuthToken) Expired() bool {
+ currentTime := time.Now()
+ return !currentTime.Before(tokens.expiredTimestamp)
+}
diff --git a/internal/server/api.go b/internal/server/api.go
index eb55bd8..65aadca 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -64,13 +64,13 @@ func apiAuthorized(
url.Path = path.Join(url.Path, endpoint)
// Make sure the tokens are valid, this will return an error if re-login is needed
- oauthErr := EnsureTokens(server)
- if oauthErr != nil {
- return nil, nil, types.NewWrappedError(errorMessage, oauthErr)
+ token, tokenErr := HeaderToken(server)
+ if tokenErr != nil {
+ return nil, nil, types.NewWrappedError(errorMessage, tokenErr)
}
headerKey := "Authorization"
- headerValue := fmt.Sprintf("Bearer %s", HeaderToken(server))
+ headerValue := fmt.Sprintf("Bearer %s", token)
if opts.Headers != nil {
opts.Headers.Add(headerKey, headerValue)
} else {
diff --git a/internal/server/common.go b/internal/server/common.go
index 351b3af..7f6599a 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -258,28 +258,25 @@ func OAuthExchange(server Server) error {
return server.OAuth().Exchange()
}
-func HeaderToken(server Server) string {
- return server.OAuth().Token.Access
+func HeaderToken(server Server) (string, error) {
+ token, tokenErr := server.OAuth().AccessToken()
+ if tokenErr != nil {
+ return "", types.NewWrappedError("failed getting server token for HTTP Header", tokenErr)
+ }
+ return token, nil
}
func MarkTokenExpired(server Server) {
- server.OAuth().Token.ExpiredTimestamp = time.Now()
+ server.OAuth().SetTokenExpired()
}
func MarkTokensForRenew(server Server) {
- server.OAuth().Token = oauth.OAuthToken{}
-}
-
-func EnsureTokens(server Server) error {
- ensureErr := server.OAuth().EnsureTokens()
- if ensureErr != nil {
- return types.NewWrappedError("failed ensuring server tokens", ensureErr)
- }
- return nil
+ server.OAuth().SetTokenRenew()
}
func NeedsRelogin(server Server) bool {
- return EnsureTokens(server) != nil
+ _, tokenErr := HeaderToken(server)
+ return tokenErr != nil
}
func CancelOAuth(server Server) {