summaryrefslogtreecommitdiff
path: root/client
diff options
context:
space:
mode:
Diffstat (limited to 'client')
-rw-r--r--client/client.go108
1 files changed, 101 insertions, 7 deletions
diff --git a/client/client.go b/client/client.go
index 813f6dc..28080b0 100644
--- a/client/client.go
+++ b/client/client.go
@@ -5,6 +5,7 @@ import (
"context"
"fmt"
"strings"
+ "time"
"github.com/eduvpn/eduvpn-common/internal/config"
"github.com/eduvpn/eduvpn-common/internal/discovery"
@@ -109,7 +110,59 @@ type Client struct {
Debug bool `json:"-"`
// The Failover monitor for the current VPN connection
- Failover *failover.DroppedConMon
+ Failover *failover.DroppedConMon `json:"-"`
+
+ // TokenSetter sets the tokens in the client
+ TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"`
+
+ // TokenGetter gets the tokens from the client
+ TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"`
+}
+
+func (c *Client) updateTokens(srv server.Server) error {
+ if c.TokenGetter == nil {
+ return errors.New("no tokken getter defined")
+ }
+ pSrv, err := c.pubCurrentServer(srv)
+ if err != nil {
+ return err
+ }
+ // shouldn't happen
+ if pSrv == nil {
+ return errors.New("public server is nil when getting tokens")
+ }
+ tokens := c.TokenGetter(*pSrv)
+ if tokens == nil {
+ return errors.New("client returned nil for tokens")
+ }
+
+ server.UpdateTokens(srv, oauth.Token{
+ Access: tokens.Access,
+ Refresh: tokens.Refresh,
+ ExpiredTimestamp: time.Unix(tokens.Expires, 0),
+ })
+
+ return nil
+}
+
+func (c *Client) forwardTokens(srv server.Server) error {
+ if c.TokenSetter == nil {
+ return errors.New("no token setter defined")
+ }
+ pSrv, err := c.pubCurrentServer(srv)
+ if err != nil {
+ return err
+ }
+ if pSrv == nil {
+ return errors.New("public server is nil when updating tokens")
+ }
+ o := srv.OAuth()
+ if o == nil {
+ return errors.New("oauth was nil when forwarding tokens")
+ }
+ t := o.Token()
+ c.TokenSetter(*pSrv, t.Public())
+ return nil
}
// New creates a new client with the following parameters:
@@ -329,7 +382,18 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool)
// oauth
// TODO: This should be ck.Context()
// But needsrelogin needs a rewrite to support this properly
+
+ // first make sure we get the most up to date tokens from the client
+ err := c.updateTokens(srv)
+ if err != nil {
+ log.Logger.Debugf("failed to get tokens from client: %v", err)
+ }
if server.NeedsRelogin(context.Background(), srv) || forceauth {
+ // mark organizations as expired if the server is a secure internet server
+ b, berr := srv.Base()
+ if berr == nil && b.Type == srvtypes.TypeSecureInternet {
+ c.Discovery.MarkOrganizationsExpired()
+ }
err := c.loginCallback(ck, srv)
if err != nil {
return err
@@ -421,6 +485,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
case srvtypes.TypeSecureInternet:
dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier)
if err != nil {
+ // We mark the organizations as expired because we got an error
+ // Note that in the docs it states that it only should happen when the Org ID doesn't exist
+ // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found
+ c.Discovery.MarkOrganizationsExpired()
return err
}
srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv)
@@ -442,7 +510,15 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
}
// callbacks
- return c.callbacks(ck, srv, false)
+ err = c.callbacks(ck, srv, false)
+ if err != nil {
+ return err
+ }
+ terr := c.forwardTokens(srv)
+ if terr != nil {
+ log.Logger.Debugf("failed to forward tokens after adding: %v", terr)
+ }
+ return nil
}
func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) {
@@ -529,6 +605,14 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
cfg, err = c.config(ck, srv, pTCP, true)
}
+ // tokens might be updated, forward them
+ defer func() {
+ terr := c.forwardTokens(srv)
+ if terr != nil {
+ log.Logger.Debugf("failed to forward tokens after get config: %v", terr)
+ }
+ }()
+
// still an error, return nil with the error
if err != nil {
return nil, err
@@ -700,15 +784,18 @@ func (c *Client) Cleanup(ck *cookie.Cookie) (err error) {
if err != nil {
return err
}
- // TODO: Support cookie context here
- // if server.NeedsRelogin(context.Background(), srv) {
- // // TODO: ask client for tokens
- // }
+ err = c.updateTokens(srv)
+ if err != nil {
+ log.Logger.Debugf("failed to update tokens for disconnect: %v", err)
+ }
err = server.Disconnect(ck.Context(), srv)
if err != nil {
return err
}
- // TODO: Set tokens with callback
+ err = c.forwardTokens(srv)
+ if err != nil {
+ log.Logger.Debugf("failed to forward tokens after disconnect: %v", err)
+ }
return nil
}
@@ -740,6 +827,13 @@ func (c *Client) RenewSession(ck *cookie.Cookie) (err error) {
c.FSM.GoTransition(StateLoadingServer)
c.FSM.GoTransition(StateChosenServer)
}
+ // update tokens in the end
+ defer func() {
+ terr := c.forwardTokens(srv)
+ if terr != nil {
+ log.Logger.Debugf("failed to forward tokens after renew: %v", terr)
+ }
+ }()
// TODO: Maybe this can be deleted because we force auth now
server.MarkTokensForRenew(srv)
// run the callbacks by forcing auth