summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2024-07-17 11:20:59 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2024-07-17 14:00:03 +0000
commit7276108171b1c6af417ce5ae86ef0221280932c1 (patch)
tree3decc148de89b9ae3b3b6d04de35fea71ffb4c3b
parent5362dd10bf13f8087a13a4ab5441efa04755fcc7 (diff)
Client + Server: Pass discovery manager and lock when needed
-rw-r--r--client/client.go49
-rw-r--r--client/discovery.go79
-rw-r--r--internal/discovery/manager.go88
-rw-r--r--internal/server/institute.go10
-rw-r--r--internal/server/secureinternet.go31
-rw-r--r--internal/server/servers.go10
6 files changed, 141 insertions, 126 deletions
diff --git a/client/client.go b/client/client.go
index 743c79a..a64b4a4 100644
--- a/client/client.go
+++ b/client/client.go
@@ -54,7 +54,7 @@ type Client struct {
mu sync.Mutex
- discoMan DiscoManager
+ discoMan *discovery.Manager
}
// GettingConfig is defined here to satisfy the server.Callbacks interface
@@ -155,7 +155,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt
// set the servers
c.Servers = server.NewServers(c.Name, c, c.cfg.V2)
- c.discoMan = DiscoManager{disco: c.cfg.Discovery()}
+ c.discoMan = discovery.NewManager(c.cfg.Discovery())
if !c.hasDiscovery() {
return c, nil
@@ -289,8 +289,10 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) {
}, nil
}
-func (c *Client) locationCallback(ck *cookie.Cookie, disco *discovery.Discovery, orgID string) error {
+func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error {
+ disco, release := c.discoMan.Discovery(false)
locs := disco.SecureLocationList()
+ release()
errChan := make(chan error)
go func() {
err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{
@@ -342,14 +344,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
// We add the server because we can then obtain it in other callback functions
previousState := c.FSM.Current
- var release func()
defer func() {
if err == nil {
c.TrySave()
}
- if release != nil {
- release()
- }
// If we must run callbacks, go to the previous state if we're not in it
if !ni && !c.FSM.InState(previousState) {
c.FSM.GoTransition(previousState) //nolint:errcheck
@@ -371,16 +369,14 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.
}
}
- var disco *discovery.Discovery
- disco, release = c.discoMan.Discovery(true)
switch _type {
case srvtypes.TypeInstituteAccess:
- err = c.Servers.AddInstitute(ck.Context(), disco, identifier, ot)
+ err = c.Servers.AddInstitute(ck.Context(), c.discoMan, identifier, ot)
if err != nil {
return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier)
}
case srvtypes.TypeSecureInternet:
- err = c.Servers.AddSecure(ck.Context(), disco, identifier, ot)
+ err = c.Servers.AddSecure(ck.Context(), c.discoMan, identifier, ot)
if err != nil {
return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier)
}
@@ -418,12 +414,8 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type)
}
- var release func()
defer func() {
c.TrySave()
- if release != nil {
- release()
- }
if err == nil {
// it could be that we are not in getting config yet if we have just done authorization
c.FSM.GoTransition(StateGettingConfig) //nolint:errcheck
@@ -449,33 +441,35 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
}
ctx := ck.Context()
- var disco *discovery.Discovery
- disco, release = c.discoMan.Discovery(true)
if _type != srvtypes.TypeCustom {
+ disco, release := c.discoMan.Discovery(true)
// make sure the servers are fetched fresh
_, _, dserverr := disco.Servers(ctx)
if dserverr != nil {
log.Logger.Warningf("failed to fetch server discovery when getting config: %v", dserverr)
}
+ release()
}
var srv *server.Server
switch _type {
case srvtypes.TypeInstituteAccess:
- srv, err = c.Servers.GetInstitute(ctx, identifier, disco, tok, startup)
+ srv, err = c.Servers.GetInstitute(ctx, identifier, c.discoMan, tok, startup)
case srvtypes.TypeSecureInternet:
+ disco, release := c.discoMan.Discovery(true)
// make sure the organizations are fetched if they need an update
_, _, dorgerr := disco.Organizations(ctx)
if dorgerr != nil {
log.Logger.Warningf("failed to fetch organization discovery when getting config: %v", dorgerr)
}
- srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup)
+ release()
+ srv, err = c.Servers.GetSecure(ctx, identifier, c.discoMan, tok, startup)
var cErr *discovery.ErrCountryNotFound
if errors.As(err, &cErr) {
- err = c.locationCallback(ck, disco, identifier)
+ err = c.locationCallback(ck, identifier)
if err == nil {
- srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup)
+ srv, err = c.Servers.GetSecure(ctx, identifier, c.discoMan, tok, startup)
}
}
case srvtypes.TypeCustom:
@@ -521,10 +515,7 @@ func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error
// CurrentServer gets the current server that is configured
func (c *Client) CurrentServer() (*srvtypes.Current, error) {
- // TODO: do clients call this during a write mutex?
- disco, release := c.discoMan.Discovery(false)
- defer release()
- curr, err := c.Servers.PublicCurrent(disco)
+ curr, err := c.Servers.PublicCurrent(c.discoMan)
if err != nil {
return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved")
}
@@ -579,9 +570,7 @@ func (c *Client) Cleanup(ck *cookie.Cookie) error {
if err != nil {
return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection")
}
- disco, release := c.discoMan.Discovery(true)
- defer release()
- auth, err := srv.ServerWithCallbacks(ck.Context(), disco, tok, true)
+ auth, err := srv.ServerWithCallbacks(ck.Context(), c.discoMan, tok, true)
if err != nil {
return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection")
}
@@ -632,9 +621,7 @@ func (c *Client) RenewSession(ck *cookie.Cookie) error {
previousState := c.FSM.Current
// getting a server with no tokens means re-authorize
- disco, release := c.discoMan.Discovery(true)
- defer release()
- _, err = srv.ServerWithCallbacks(ck.Context(), disco, nil, false)
+ _, err = srv.ServerWithCallbacks(ck.Context(), c.discoMan, nil, false)
if err != nil {
c.FSM.GoTransition(previousState) //nolint:errcheck
return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session")
diff --git a/client/discovery.go b/client/discovery.go
index 2816d10..6cc76f6 100644
--- a/client/discovery.go
+++ b/client/discovery.go
@@ -4,11 +4,8 @@ import (
"context"
"sort"
"strings"
- "sync"
"github.com/eduvpn/eduvpn-common/i18nerr"
- "github.com/eduvpn/eduvpn-common/internal/discovery"
- "github.com/eduvpn/eduvpn-common/internal/log"
"github.com/eduvpn/eduvpn-common/types/cookie"
discotypes "github.com/eduvpn/eduvpn-common/types/discovery"
)
@@ -138,79 +135,3 @@ func (c *Client) DiscoveryStartup(cb func()) error {
c.discoMan.Startup(context.Background(), fcb)
return nil
}
-
-type DiscoManager struct {
- disco *discovery.Discovery
-
- cancel context.CancelFunc
- mu sync.RWMutex
- wait sync.WaitGroup
-}
-
-func (m *DiscoManager) lock(write bool) {
- log.Logger.Debugf("Locking write: %v", write)
- if write {
- m.mu.Lock()
- return
- }
- m.mu.RLock()
-}
-
-func (m *DiscoManager) unlock(write bool) {
- log.Logger.Debugf("Unlocking write: %v", write)
- if write {
- m.mu.Unlock()
- return
- }
- m.mu.RUnlock()
-}
-
-func (m *DiscoManager) Discovery(write bool) (*discovery.Discovery, func()) {
- log.Logger.Debugf("Requesting discovery write: %v", write)
- if write {
- m.wait.Wait()
- }
- m.lock(write)
- return m.disco, func() {
- m.unlock(write)
- }
-}
-
-func (m *DiscoManager) Cancel() {
- if m.cancel != nil {
- m.cancel()
- }
- m.wait.Wait()
-}
-
-func (m *DiscoManager) Startup(ctx context.Context, cb func()) {
- ctx, cancel := context.WithCancel(ctx)
- m.cancel = cancel
- m.wait.Add(1)
- go func() {
- m.lock(false)
- discoCopy, err := m.disco.Copy()
- if err != nil {
- log.Logger.Warningf("internal error, failed to clone discovery, %v", err)
- return
- }
- m.unlock(false)
- // we already log the warning
- discoCopy.Servers(ctx) //nolint:errcheck
-
- m.lock(true)
- m.disco.UpdateServers(discoCopy)
- m.unlock(true)
- m.wait.Done()
-
- select {
- case <-ctx.Done():
- return
- default:
- if cb == nil {
- return
- }
- cb()
- }
- }()
-}
diff --git a/internal/discovery/manager.go b/internal/discovery/manager.go
new file mode 100644
index 0000000..b84e770
--- /dev/null
+++ b/internal/discovery/manager.go
@@ -0,0 +1,88 @@
+package discovery
+
+import (
+ "context"
+ "sync"
+
+ "github.com/eduvpn/eduvpn-common/internal/log"
+)
+
+type Manager struct {
+ disco *Discovery
+
+ cancel context.CancelFunc
+ mu sync.RWMutex
+ wait sync.WaitGroup
+}
+
+func NewManager(disco *Discovery) *Manager {
+ return &Manager{disco: disco}
+}
+
+func (m *Manager) lock(write bool) {
+ log.Logger.Debugf("Locking write: %v", write)
+ if write {
+ m.mu.Lock()
+ return
+ }
+ m.mu.RLock()
+}
+
+func (m *Manager) unlock(write bool) {
+ log.Logger.Debugf("Unlocking write: %v", write)
+ if write {
+ m.mu.Unlock()
+ return
+ }
+ m.mu.RUnlock()
+}
+
+func (m *Manager) Discovery(write bool) (*Discovery, func()) {
+ log.Logger.Debugf("Requesting discovery write: %v", write)
+ if write {
+ m.wait.Wait()
+ }
+ m.lock(write)
+ return m.disco, func() {
+ m.unlock(write)
+ }
+}
+
+func (m *Manager) Cancel() {
+ if m.cancel != nil {
+ m.cancel()
+ }
+ m.wait.Wait()
+}
+
+func (m *Manager) Startup(ctx context.Context, cb func()) {
+ ctx, cancel := context.WithCancel(ctx)
+ m.cancel = cancel
+ m.wait.Add(1)
+ go func() {
+ m.lock(false)
+ discoCopy, err := m.disco.Copy()
+ if err != nil {
+ log.Logger.Warningf("internal error, failed to clone discovery, %v", err)
+ return
+ }
+ m.unlock(false)
+ // we already log the warning
+ discoCopy.Servers(ctx) //nolint:errcheck
+
+ m.lock(true)
+ m.disco.UpdateServers(discoCopy)
+ m.unlock(true)
+ m.wait.Done()
+
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ if cb == nil {
+ return
+ }
+ cb()
+ }
+ }()
+}
diff --git a/internal/server/institute.go b/internal/server/institute.go
index caae004..4ddf939 100644
--- a/internal/server/institute.go
+++ b/internal/server/institute.go
@@ -17,12 +17,15 @@ import (
// `disco` are the discovery servers
// `id` is the identifier for the server, the base url
// `ot` specifies specifies the start time OAuth was already triggered
-func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, ot *int64) error {
+func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, id string, ot *int64) error {
// This is basically done to double check if the server is part of the institute access section of disco
+ disco, release := discom.Discovery(false)
dsrv, err := disco.ServerByURL(id, "institute_access")
if err != nil {
+ release()
return err
}
+ release()
sd := api.ServerData{
ID: dsrv.BaseURL,
@@ -67,12 +70,15 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery,
// `disco` are the discovery servers
// `tok` are the tokens such that we do not have to trigger auth
// `disableAuth` is true when auth should never be triggered
-func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+func (s *Servers) GetInstitute(ctx context.Context, id string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+ disco, release := discom.Discovery(false)
// This is basically done to double check if the server is part of the institute access section of disco
dsrv, err := disco.ServerByURL(id, "institute_access")
if err != nil {
+ release()
return nil, err
}
+ release()
// Get the server from the config
_, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess)
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index 990ceb3..9b4b873 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -19,14 +19,17 @@ import (
// `disco` are the discovery servers
// `orgID` is the organiztaion ID
// `ot` specifies specifies the start time OAuth was already triggered
-func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, ot *int64) error {
+func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgID string, ot *int64) error {
if s.config.HasSecureInternet() {
return errors.New("a secure internet server already exists")
}
+ disco, release := discom.Discovery(false)
dorg, dsrv, err := disco.SecureHomeArgs(orgID)
if err != nil {
+ release()
return err
}
+ release()
sd := api.ServerData{
ID: dorg.OrgID,
@@ -34,11 +37,13 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org
BaseWK: dsrv.BaseURL,
BaseAuthWK: dsrv.BaseURL,
ProcessAuth: func(ctx context.Context, url string) (string, error) {
+ newd, release := discom.Discovery(true)
+ defer release()
// the only thing we can do is log warn
// this is already done in the functions
- disco.Servers(ctx) //nolint:errcheck
- disco.Organizations(ctx) //nolint:errcheck
- updorg, updsrv, err := disco.SecureHomeArgs(orgID)
+ newd.Servers(ctx) //nolint:errcheck
+ newd.Organizations(ctx) //nolint:errcheck
+ updorg, updsrv, err := newd.SecureHomeArgs(orgID)
if err != nil {
return "", err
}
@@ -84,21 +89,25 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org
// `disco` are the discovery servers
// `tok` are the tokens such that the server can be found without triggering auth
// `disableAuth` is set to true when authorization should not be triggered
-func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
srv, err := s.config.GetServer(orgID, server.TypeSecureInternet)
if err != nil {
return nil, err
}
+ disco, release := discom.Discovery(false)
dorg, dhome, err := disco.SecureHomeArgs(orgID)
if err != nil {
+ release()
return nil, err
}
dloc, err := disco.ServerByCountryCode(srv.CountryCode)
if err != nil {
+ release()
return nil, err
}
+ release()
sd := api.ServerData{
ID: dorg.OrgID,
@@ -106,13 +115,15 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.
BaseWK: dloc.BaseURL,
BaseAuthWK: dhome.BaseURL,
ProcessAuth: func(ctx context.Context, url string) (string, error) {
+ newd, release := discom.Discovery(true)
+ defer release()
// the only thing we can do is log warn
// this is already done in the functions
- disco.MarkServersExpired()
- disco.Servers(ctx) //nolint:errcheck
- disco.MarkOrganizationsExpired()
- disco.Organizations(ctx) //nolint:errcheck
- updorg, updsrv, err := disco.SecureHomeArgs(orgID)
+ newd.MarkServersExpired()
+ newd.Servers(ctx) //nolint:errcheck
+ newd.MarkOrganizationsExpired()
+ newd.Organizations(ctx) //nolint:errcheck
+ updorg, updsrv, err := newd.SecureHomeArgs(orgID)
if err != nil {
return "", err
}
diff --git a/internal/server/servers.go b/internal/server/servers.go
index 8285801..d06c37d 100644
--- a/internal/server/servers.go
+++ b/internal/server/servers.go
@@ -54,12 +54,12 @@ type CurrentServer struct {
}
// ServerWithCallbacks gets the current server as a server struct and triggers callbacks as needed
-func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) {
+func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, discom *discovery.Manager, tokens *eduoauth.Token, disableAuth bool) (*Server, error) {
switch cs.Key.T {
case srvtypes.TypeInstituteAccess:
- return cs.srvs.GetInstitute(ctx, cs.Key.ID, disco, tokens, disableAuth)
+ return cs.srvs.GetInstitute(ctx, cs.Key.ID, discom, tokens, disableAuth)
case srvtypes.TypeSecureInternet:
- return cs.srvs.GetSecure(ctx, cs.Key.ID, disco, tokens, disableAuth)
+ return cs.srvs.GetSecure(ctx, cs.Key.ID, discom, tokens, disableAuth)
case srvtypes.TypeCustom:
return cs.srvs.GetCustom(ctx, cs.Key.ID, tokens, disableAuth)
default:
@@ -89,7 +89,9 @@ func (s *Servers) CurrentServer() (*CurrentServer, error) {
}
// PublicCurrent gets the current server into a type that we can return to the client
-func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) {
+func (s *Servers) PublicCurrent(discom *discovery.Manager) (*srvtypes.Current, error) {
+ disco, release := discom.Discovery(false)
+ defer release()
return s.config.PublicCurrent(disco)
}