summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-20 15:43:55 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-12-21 18:28:50 +0100
commit12838c19514459974cf0a71c42f1248b1cb9419c (patch)
treea4254d20bb7b0ef49a2fa6c12753eb4c5acb64d1 /internal
parent6981666c6d8f639a1ff9c09a3bc08769e19928af (diff)
Exports + OAuth + Server: Forward tokens to getting a config
Diffstat (limited to 'internal')
-rw-r--r--internal/oauth/oauth.go24
-rw-r--r--internal/oauth/token.go74
-rw-r--r--internal/server/server.go53
3 files changed, 108 insertions, 43 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index ce86337..6d21c82 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -196,10 +196,19 @@ func (oauth *OAuth) SetTokenExpired() {
// SetTokenRenew sets the tokens for renewal by completely clearing the structure.
func (oauth *OAuth) SetTokenRenew() {
if oauth.token != nil {
- oauth.token.Clear()
+ oauth.token.Update(Token{})
}
}
+func (oauth *OAuth) Token() Token {
+ t := Token{}
+ if oauth.token != nil {
+ t = oauth.token.Get()
+ }
+
+ return t
+}
+
// tokensWithAuthCode gets the access and refresh tokens using the authorization code
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2
@@ -239,10 +248,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error {
return errors.New("No token response after authorization code")
}
- oauth.token.Update(*tr, now)
+ oauth.token.UpdateResponse(*tr, now)
return nil
}
+func (oauth *OAuth) UpdateTokens(t Token) {
+ if oauth.token == nil {
+ oauth.token = &tokenLock{t: &tokenRefresher{Refresher: oauth.refreshResponse}}
+ }
+ oauth.token.Update(t)
+}
+
// refreshResponse gets the refresh token response with a refresh token
// This response contains the access and refresh tokens, together with a timestamp
// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4
@@ -420,8 +436,8 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s
return "", errors.WrapPrefix(err, "genState error", 0)
}
- // Fill the oauth tokens
- oauth.token = &tokenLock{t: &token{Refresher: oauth.refreshResponse}}
+ // Re-initialize the token structure
+ oauth.UpdateTokens(Token{})
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
oauth.session = exchangeSession{
diff --git a/internal/oauth/token.go b/internal/oauth/token.go
index 855677c..4ed8f43 100644
--- a/internal/oauth/token.go
+++ b/internal/oauth/token.go
@@ -23,17 +23,23 @@ type TokenResponse struct {
Expires int64 `json:"expires_in"`
}
-// token is a structure that contains our access and refresh tokens and a timestamp when they expire.
-type token struct {
- // Access is the access token returned by the server
- access string
+// The public type that can be passed to an update function
+// It contains our access and refresh tokens with a timestamp
+type Token struct {
+ // Access is the Access token returned by the server
+ Access string
- // Refresh token is the refresh token returned by the server
- refresh string
+ // Refresh token is the Refresh token returned by the server
+ Refresh string
// ExpiredTimestamp is the Expires field but converted to a Go timestamp
- expiredTimestamp time.Time
+ ExpiredTimestamp time.Time
+}
+// tokenRefresher is a structure that contains our access and refresh tokens and a timestamp when they expire.
+// Additionally, it contains the refresher to get new tokens
+type tokenRefresher struct {
+ Token
// Refresher is the function that refreshes the token
Refresher func(string) (*TokenResponse, time.Time, error)
}
@@ -44,7 +50,8 @@ type tokenLock struct {
mu sync.Mutex
// The token fields protected by the lock
- t *token
+ // This token struct contains a refresher
+ t *tokenRefresher
}
// Access gets the OAuth access token used for contacting the server API
@@ -57,17 +64,17 @@ func (l *tokenLock) Access() (string, error) {
// The tokens are not expired yet
// So they should be valid, re-login not neede
if !l.expired() {
- return l.t.access, nil
+ return l.t.Access, nil
}
// Check if refresh is even possible by doing a simple check if the refresh token is empty
// This is not needed but reduces API calls to the server
- if l.t.refresh == "" {
+ if l.t.Refresh == "" {
return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0)
}
// Otherwise refresh and then later return the access token if we are successful
- tr, s, err := l.t.Refresher(l.t.refresh)
+ tr, s, err := l.t.Refresher(l.t.Refresh)
if err != nil {
// We have failed to ensure the tokens due to refresh not working
return "", errors.Wrap(
@@ -76,37 +83,50 @@ func (l *tokenLock) Access() (string, error) {
if tr == nil {
return "", errors.New("No token response after refreshing")
}
- l.updateInternal(*tr, s)
- return l.t.access, nil
+ r := *tr
+ e := s.Add(time.Second * time.Duration(r.Expires))
+ t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e}
+ l.updateInternal(t)
+ return l.t.Access, nil
}
-// Clear completely clears the token structure
-// This is useful for forcing re-authorization
-func (l *tokenLock) Clear() {
+// UpdateResponse updates the structure using the server response and locks
+func (l *tokenLock) UpdateResponse(r TokenResponse, s time.Time) {
l.mu.Lock()
- l.t = &token{}
+ e := s.Add(time.Second * time.Duration(r.Expires))
+ t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e}
+ l.updateInternal(t)
l.mu.Unlock()
}
-// updateInternal updates the structure using the response without locking
-func (l *tokenLock) updateInternal(r TokenResponse, s time.Time) {
- l.t.access = r.Access
- l.t.refresh = r.Refresh
- l.t.expiredTimestamp = s.Add(time.Second * time.Duration(r.Expires))
+// updateInternal updates the token structure internally but does not lock
+func (l *tokenLock) updateInternal(r Token) {
+ l.t.Access = r.Access
+ l.t.Refresh = r.Refresh
+ l.t.ExpiredTimestamp = r.ExpiredTimestamp
}
-// Update updates the structure usign the response and locks
-func (l *tokenLock) Update(r TokenResponse, s time.Time) {
+// Update updates the token structure using the internal function but locks
+func (l *tokenLock) Update(r Token) {
l.mu.Lock()
- l.updateInternal(r, s)
+ l.updateInternal(r)
l.mu.Unlock()
}
+
+// Get gets the tokens into a public struct
+func (l *tokenLock) Get() Token {
+ // TODO: Check nil?
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ return l.t.Token
+}
+
// SetExpired overrides the timestamp to the current time
// This marks the tokens as expired
func (l *tokenLock) SetExpired() {
l.mu.Lock()
- l.t.expiredTimestamp = time.Now()
+ l.t.ExpiredTimestamp = time.Now()
l.mu.Unlock()
}
@@ -114,5 +134,5 @@ func (l *tokenLock) SetExpired() {
// This is only called internally and thus does not lock
func (l *tokenLock) expired() bool {
now := time.Now()
- return !now.Before(l.t.expiredTimestamp)
+ return !now.Before(l.t.ExpiredTimestamp)
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 78f6472..1585264 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -77,6 +77,10 @@ func ShouldRenewButton(srv Server) bool {
return true
}
+func UpdateTokens(srv Server, t oauth.Token) {
+ srv.OAuth().UpdateTokens(t)
+}
+
func OAuthURL(srv Server, name string) (string, error) {
return srv.OAuth().AuthURL(name, srv.TemplateAuth())
}
@@ -134,22 +138,33 @@ func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) {
return &ps, nil
}
-func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) {
+type ConfigData struct {
+ // The configuration
+ Config string
+
+ // The type of configuration
+ Type string
+
+ // The tokens
+ Tokens oauth.Token
+}
+
+func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
- return "", "", err
+ return nil, err
}
pID := b.Profiles.Current
key, err := wireguard.GenerateKey()
if err != nil {
- return "", "", err
+ return nil, err
}
pub := key.PublicKey().String()
cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport)
if err != nil {
- return "", "", err
+ return nil, err
}
// Store start and end time
@@ -164,25 +179,39 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string
cfg = wireguard.ConfigAddKey(cfg, key)
}
- return cfg, proto, nil
+ t := oauth.Token{}
+
+ o := srv.OAuth()
+ if o != nil {
+ t = o.Token()
+ }
+
+ return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil
}
-func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) {
+func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) {
b, err := srv.Base()
if err != nil {
- return "", "", err
+ return nil, err
}
pid := b.Profiles.Current
cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP)
if err != nil {
- return "", "", err
+ return nil, err
}
// Store start and end time
b.StartTime = time.Now()
b.EndTime = exp
- return cfg, "openvpn", nil
+ t := oauth.Token{}
+
+ o := srv.OAuth()
+ if o != nil {
+ t = o.Token()
+ }
+
+ return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil
}
func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) {
@@ -237,10 +266,10 @@ func RefreshEndpoints(srv Server) error {
return b.InitializeEndpoints()
}
-func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) {
+func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
p, err := CurrentProfile(server)
if err != nil {
- return "", "", err
+ return nil, err
}
ovpn := p.SupportsOpenVPN()
@@ -266,7 +295,7 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (string, strin
return openVPNGetConfig(server, preferTCP)
// The config supports no available protocol because the profile only supports WireGuard but the client doesn't
default:
- return "", "", errors.Errorf("no supported protocol found")
+ return nil, errors.Errorf("no supported protocol found")
}
}