diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/oauth/oauth.go | 24 | ||||
| -rw-r--r-- | internal/oauth/token.go | 74 | ||||
| -rw-r--r-- | internal/server/server.go | 53 |
3 files changed, 108 insertions, 43 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index ce86337..6d21c82 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -196,10 +196,19 @@ func (oauth *OAuth) SetTokenExpired() { // SetTokenRenew sets the tokens for renewal by completely clearing the structure. func (oauth *OAuth) SetTokenRenew() { if oauth.token != nil { - oauth.token.Clear() + oauth.token.Update(Token{}) } } +func (oauth *OAuth) Token() Token { + t := Token{} + if oauth.token != nil { + t = oauth.token.Get() + } + + return t +} + // tokensWithAuthCode gets the access and refresh tokens using the authorization code // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 @@ -239,10 +248,17 @@ func (oauth *OAuth) tokensWithAuthCode(authCode string) error { return errors.New("No token response after authorization code") } - oauth.token.Update(*tr, now) + oauth.token.UpdateResponse(*tr, now) return nil } +func (oauth *OAuth) UpdateTokens(t Token) { + if oauth.token == nil { + oauth.token = &tokenLock{t: &tokenRefresher{Refresher: oauth.refreshResponse}} + } + oauth.token.Update(t) +} + // refreshResponse gets the refresh token response with a refresh token // This response contains the access and refresh tokens, together with a timestamp // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 @@ -420,8 +436,8 @@ func (oauth *OAuth) AuthURL(name string, postProcessAuth func(string) string) (s return "", errors.WrapPrefix(err, "genState error", 0) } - // Fill the oauth tokens - oauth.token = &tokenLock{t: &token{Refresher: oauth.refreshResponse}} + // Re-initialize the token structure + oauth.UpdateTokens(Token{}) // Fill the struct with the necessary fields filled for the next call to getting the HTTP client oauth.session = exchangeSession{ diff --git a/internal/oauth/token.go b/internal/oauth/token.go index 855677c..4ed8f43 100644 --- a/internal/oauth/token.go +++ b/internal/oauth/token.go @@ -23,17 +23,23 @@ type TokenResponse struct { Expires int64 `json:"expires_in"` } -// token is a structure that contains our access and refresh tokens and a timestamp when they expire. -type token struct { - // Access is the access token returned by the server - access string +// The public type that can be passed to an update function +// It contains our access and refresh tokens with a timestamp +type Token struct { + // Access is the Access token returned by the server + Access string - // Refresh token is the refresh token returned by the server - refresh string + // Refresh token is the Refresh token returned by the server + Refresh string // ExpiredTimestamp is the Expires field but converted to a Go timestamp - expiredTimestamp time.Time + ExpiredTimestamp time.Time +} +// tokenRefresher is a structure that contains our access and refresh tokens and a timestamp when they expire. +// Additionally, it contains the refresher to get new tokens +type tokenRefresher struct { + Token // Refresher is the function that refreshes the token Refresher func(string) (*TokenResponse, time.Time, error) } @@ -44,7 +50,8 @@ type tokenLock struct { mu sync.Mutex // The token fields protected by the lock - t *token + // This token struct contains a refresher + t *tokenRefresher } // Access gets the OAuth access token used for contacting the server API @@ -57,17 +64,17 @@ func (l *tokenLock) Access() (string, error) { // The tokens are not expired yet // So they should be valid, re-login not neede if !l.expired() { - return l.t.access, nil + return l.t.Access, nil } // Check if refresh is even possible by doing a simple check if the refresh token is empty // This is not needed but reduces API calls to the server - if l.t.refresh == "" { + if l.t.Refresh == "" { return "", errors.Wrap(&TokensInvalidError{Cause: "no refresh token is present"}, 0) } // Otherwise refresh and then later return the access token if we are successful - tr, s, err := l.t.Refresher(l.t.refresh) + tr, s, err := l.t.Refresher(l.t.Refresh) if err != nil { // We have failed to ensure the tokens due to refresh not working return "", errors.Wrap( @@ -76,37 +83,50 @@ func (l *tokenLock) Access() (string, error) { if tr == nil { return "", errors.New("No token response after refreshing") } - l.updateInternal(*tr, s) - return l.t.access, nil + r := *tr + e := s.Add(time.Second * time.Duration(r.Expires)) + t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e} + l.updateInternal(t) + return l.t.Access, nil } -// Clear completely clears the token structure -// This is useful for forcing re-authorization -func (l *tokenLock) Clear() { +// UpdateResponse updates the structure using the server response and locks +func (l *tokenLock) UpdateResponse(r TokenResponse, s time.Time) { l.mu.Lock() - l.t = &token{} + e := s.Add(time.Second * time.Duration(r.Expires)) + t := Token{Access: r.Access, Refresh: r.Refresh, ExpiredTimestamp: e} + l.updateInternal(t) l.mu.Unlock() } -// updateInternal updates the structure using the response without locking -func (l *tokenLock) updateInternal(r TokenResponse, s time.Time) { - l.t.access = r.Access - l.t.refresh = r.Refresh - l.t.expiredTimestamp = s.Add(time.Second * time.Duration(r.Expires)) +// updateInternal updates the token structure internally but does not lock +func (l *tokenLock) updateInternal(r Token) { + l.t.Access = r.Access + l.t.Refresh = r.Refresh + l.t.ExpiredTimestamp = r.ExpiredTimestamp } -// Update updates the structure usign the response and locks -func (l *tokenLock) Update(r TokenResponse, s time.Time) { +// Update updates the token structure using the internal function but locks +func (l *tokenLock) Update(r Token) { l.mu.Lock() - l.updateInternal(r, s) + l.updateInternal(r) l.mu.Unlock() } + +// Get gets the tokens into a public struct +func (l *tokenLock) Get() Token { + // TODO: Check nil? + l.mu.Lock() + defer l.mu.Unlock() + return l.t.Token +} + // SetExpired overrides the timestamp to the current time // This marks the tokens as expired func (l *tokenLock) SetExpired() { l.mu.Lock() - l.t.expiredTimestamp = time.Now() + l.t.ExpiredTimestamp = time.Now() l.mu.Unlock() } @@ -114,5 +134,5 @@ func (l *tokenLock) SetExpired() { // This is only called internally and thus does not lock func (l *tokenLock) expired() bool { now := time.Now() - return !now.Before(l.t.expiredTimestamp) + return !now.Before(l.t.ExpiredTimestamp) } diff --git a/internal/server/server.go b/internal/server/server.go index 78f6472..1585264 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -77,6 +77,10 @@ func ShouldRenewButton(srv Server) bool { return true } +func UpdateTokens(srv Server, t oauth.Token) { + srv.OAuth().UpdateTokens(t) +} + func OAuthURL(srv Server, name string) (string, error) { return srv.OAuth().AuthURL(name, srv.TemplateAuth()) } @@ -134,22 +138,33 @@ func ValidProfiles(srv Server, wireguardSupport bool) (*ProfileInfo, error) { return &ps, nil } -func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string, string, error) { +type ConfigData struct { + // The configuration + Config string + + // The type of configuration + Type string + + // The tokens + Tokens oauth.Token +} + +func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { - return "", "", err + return nil, err } pID := b.Profiles.Current key, err := wireguard.GenerateKey() if err != nil { - return "", "", err + return nil, err } pub := key.PublicKey().String() cfg, proto, exp, err := APIConnectWireguard(srv, pID, pub, preferTCP, openVPNSupport) if err != nil { - return "", "", err + return nil, err } // Store start and end time @@ -164,25 +179,39 @@ func wireguardGetConfig(srv Server, preferTCP bool, openVPNSupport bool) (string cfg = wireguard.ConfigAddKey(cfg, key) } - return cfg, proto, nil + t := oauth.Token{} + + o := srv.OAuth() + if o != nil { + t = o.Token() + } + + return &ConfigData{Config: cfg, Type: proto, Tokens: t}, nil } -func openVPNGetConfig(srv Server, preferTCP bool) (string, string, error) { +func openVPNGetConfig(srv Server, preferTCP bool) (*ConfigData, error) { b, err := srv.Base() if err != nil { - return "", "", err + return nil, err } pid := b.Profiles.Current cfg, exp, err := APIConnectOpenVPN(srv, pid, preferTCP) if err != nil { - return "", "", err + return nil, err } // Store start and end time b.StartTime = time.Now() b.EndTime = exp - return cfg, "openvpn", nil + t := oauth.Token{} + + o := srv.OAuth() + if o != nil { + t = o.Token() + } + + return &ConfigData{Config: cfg, Type: "openvpn", Tokens: t}, nil } func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { @@ -237,10 +266,10 @@ func RefreshEndpoints(srv Server) error { return b.InitializeEndpoints() } -func Config(server Server, wireguardSupport bool, preferTCP bool) (string, string, error) { +func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { - return "", "", err + return nil, err } ovpn := p.SupportsOpenVPN() @@ -266,7 +295,7 @@ func Config(server Server, wireguardSupport bool, preferTCP bool) (string, strin return openVPNGetConfig(server, preferTCP) // The config supports no available protocol because the profile only supports WireGuard but the client doesn't default: - return "", "", errors.Errorf("no supported protocol found") + return nil, errors.Errorf("no supported protocol found") } } |
