diff options
Diffstat (limited to 'client/client.go')
| -rw-r--r-- | client/client.go | 108 |
1 files changed, 101 insertions, 7 deletions
diff --git a/client/client.go b/client/client.go index 813f6dc..28080b0 100644 --- a/client/client.go +++ b/client/client.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/eduvpn/eduvpn-common/internal/config" "github.com/eduvpn/eduvpn-common/internal/discovery" @@ -109,7 +110,59 @@ type Client struct { Debug bool `json:"-"` // The Failover monitor for the current VPN connection - Failover *failover.DroppedConMon + Failover *failover.DroppedConMon `json:"-"` + + // TokenSetter sets the tokens in the client + TokenSetter func(srv srvtypes.Current, tok srvtypes.Tokens) `json:"-"` + + // TokenGetter gets the tokens from the client + TokenGetter func(srv srvtypes.Current) *srvtypes.Tokens `json:"-"` +} + +func (c *Client) updateTokens(srv server.Server) error { + if c.TokenGetter == nil { + return errors.New("no tokken getter defined") + } + pSrv, err := c.pubCurrentServer(srv) + if err != nil { + return err + } + // shouldn't happen + if pSrv == nil { + return errors.New("public server is nil when getting tokens") + } + tokens := c.TokenGetter(*pSrv) + if tokens == nil { + return errors.New("client returned nil for tokens") + } + + server.UpdateTokens(srv, oauth.Token{ + Access: tokens.Access, + Refresh: tokens.Refresh, + ExpiredTimestamp: time.Unix(tokens.Expires, 0), + }) + + return nil +} + +func (c *Client) forwardTokens(srv server.Server) error { + if c.TokenSetter == nil { + return errors.New("no token setter defined") + } + pSrv, err := c.pubCurrentServer(srv) + if err != nil { + return err + } + if pSrv == nil { + return errors.New("public server is nil when updating tokens") + } + o := srv.OAuth() + if o == nil { + return errors.New("oauth was nil when forwarding tokens") + } + t := o.Token() + c.TokenSetter(*pSrv, t.Public()) + return nil } // New creates a new client with the following parameters: @@ -329,7 +382,18 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool) // oauth // TODO: This should be ck.Context() // But needsrelogin needs a rewrite to support this properly + + // first make sure we get the most up to date tokens from the client + err := c.updateTokens(srv) + if err != nil { + log.Logger.Debugf("failed to get tokens from client: %v", err) + } if server.NeedsRelogin(context.Background(), srv) || forceauth { + // mark organizations as expired if the server is a secure internet server + b, berr := srv.Base() + if berr == nil && b.Type == srvtypes.TypeSecureInternet { + c.Discovery.MarkOrganizationsExpired() + } err := c.loginCallback(ck, srv) if err != nil { return err @@ -421,6 +485,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. case srvtypes.TypeSecureInternet: dOrg, dSrv, err := c.Discovery.SecureHomeArgs(identifier) if err != nil { + // We mark the organizations as expired because we got an error + // Note that in the docs it states that it only should happen when the Org ID doesn't exist + // However, this is nice as well because it also catches the error where the SecureInternetHome server is not found + c.Discovery.MarkOrganizationsExpired() return err } srv, err = c.Servers.AddSecureInternet(ck.Context(), dOrg, dSrv) @@ -442,7 +510,15 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } // callbacks - return c.callbacks(ck, srv, false) + err = c.callbacks(ck, srv, false) + if err != nil { + return err + } + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after adding: %v", terr) + } + return nil } func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAuth bool) (cfg *srvtypes.Configuration, err error) { @@ -529,6 +605,14 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. cfg, err = c.config(ck, srv, pTCP, true) } + // tokens might be updated, forward them + defer func() { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after get config: %v", terr) + } + }() + // still an error, return nil with the error if err != nil { return nil, err @@ -700,15 +784,18 @@ func (c *Client) Cleanup(ck *cookie.Cookie) (err error) { if err != nil { return err } - // TODO: Support cookie context here - // if server.NeedsRelogin(context.Background(), srv) { - // // TODO: ask client for tokens - // } + err = c.updateTokens(srv) + if err != nil { + log.Logger.Debugf("failed to update tokens for disconnect: %v", err) + } err = server.Disconnect(ck.Context(), srv) if err != nil { return err } - // TODO: Set tokens with callback + err = c.forwardTokens(srv) + if err != nil { + log.Logger.Debugf("failed to forward tokens after disconnect: %v", err) + } return nil } @@ -740,6 +827,13 @@ func (c *Client) RenewSession(ck *cookie.Cookie) (err error) { c.FSM.GoTransition(StateLoadingServer) c.FSM.GoTransition(StateChosenServer) } + // update tokens in the end + defer func() { + terr := c.forwardTokens(srv) + if terr != nil { + log.Logger.Debugf("failed to forward tokens after renew: %v", terr) + } + }() // TODO: Maybe this can be deleted because we force auth now server.MarkTokensForRenew(srv) // run the callbacks by forcing auth |
