diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-07-17 11:20:59 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-07-17 14:00:03 +0000 |
| commit | 7276108171b1c6af417ce5ae86ef0221280932c1 (patch) | |
| tree | 3decc148de89b9ae3b3b6d04de35fea71ffb4c3b | |
| parent | 5362dd10bf13f8087a13a4ab5441efa04755fcc7 (diff) | |
Client + Server: Pass discovery manager and lock when needed
| -rw-r--r-- | client/client.go | 49 | ||||
| -rw-r--r-- | client/discovery.go | 79 | ||||
| -rw-r--r-- | internal/discovery/manager.go | 88 | ||||
| -rw-r--r-- | internal/server/institute.go | 10 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 31 | ||||
| -rw-r--r-- | internal/server/servers.go | 10 |
6 files changed, 141 insertions, 126 deletions
diff --git a/client/client.go b/client/client.go index 743c79a..a64b4a4 100644 --- a/client/client.go +++ b/client/client.go @@ -54,7 +54,7 @@ type Client struct { mu sync.Mutex - discoMan DiscoManager + discoMan *discovery.Manager } // GettingConfig is defined here to satisfy the server.Callbacks interface @@ -155,7 +155,7 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // set the servers c.Servers = server.NewServers(c.Name, c, c.cfg.V2) - c.discoMan = DiscoManager{disco: c.cfg.Discovery()} + c.discoMan = discovery.NewManager(c.cfg.Discovery()) if !c.hasDiscovery() { return c, nil @@ -289,8 +289,10 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func (c *Client) locationCallback(ck *cookie.Cookie, disco *discovery.Discovery, orgID string) error { +func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error { + disco, release := c.discoMan.Discovery(false) locs := disco.SecureLocationList() + release() errChan := make(chan error) go func() { err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ @@ -342,14 +344,10 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. // We add the server because we can then obtain it in other callback functions previousState := c.FSM.Current - var release func() defer func() { if err == nil { c.TrySave() } - if release != nil { - release() - } // If we must run callbacks, go to the previous state if we're not in it if !ni && !c.FSM.InState(previousState) { c.FSM.GoTransition(previousState) //nolint:errcheck @@ -371,16 +369,14 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } } - var disco *discovery.Discovery - disco, release = c.discoMan.Discovery(true) switch _type { case srvtypes.TypeInstituteAccess: - err = c.Servers.AddInstitute(ck.Context(), disco, identifier, ot) + err = c.Servers.AddInstitute(ck.Context(), c.discoMan, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier) } case srvtypes.TypeSecureInternet: - err = c.Servers.AddSecure(ck.Context(), disco, identifier, ot) + err = c.Servers.AddSecure(ck.Context(), c.discoMan, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier) } @@ -418,12 +414,8 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type) } - var release func() defer func() { c.TrySave() - if release != nil { - release() - } if err == nil { // it could be that we are not in getting config yet if we have just done authorization c.FSM.GoTransition(StateGettingConfig) //nolint:errcheck @@ -449,33 +441,35 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. } ctx := ck.Context() - var disco *discovery.Discovery - disco, release = c.discoMan.Discovery(true) if _type != srvtypes.TypeCustom { + disco, release := c.discoMan.Discovery(true) // make sure the servers are fetched fresh _, _, dserverr := disco.Servers(ctx) if dserverr != nil { log.Logger.Warningf("failed to fetch server discovery when getting config: %v", dserverr) } + release() } var srv *server.Server switch _type { case srvtypes.TypeInstituteAccess: - srv, err = c.Servers.GetInstitute(ctx, identifier, disco, tok, startup) + srv, err = c.Servers.GetInstitute(ctx, identifier, c.discoMan, tok, startup) case srvtypes.TypeSecureInternet: + disco, release := c.discoMan.Discovery(true) // make sure the organizations are fetched if they need an update _, _, dorgerr := disco.Organizations(ctx) if dorgerr != nil { log.Logger.Warningf("failed to fetch organization discovery when getting config: %v", dorgerr) } - srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup) + release() + srv, err = c.Servers.GetSecure(ctx, identifier, c.discoMan, tok, startup) var cErr *discovery.ErrCountryNotFound if errors.As(err, &cErr) { - err = c.locationCallback(ck, disco, identifier) + err = c.locationCallback(ck, identifier) if err == nil { - srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup) + srv, err = c.Servers.GetSecure(ctx, identifier, c.discoMan, tok, startup) } } case srvtypes.TypeCustom: @@ -521,10 +515,7 @@ func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error // CurrentServer gets the current server that is configured func (c *Client) CurrentServer() (*srvtypes.Current, error) { - // TODO: do clients call this during a write mutex? - disco, release := c.discoMan.Discovery(false) - defer release() - curr, err := c.Servers.PublicCurrent(disco) + curr, err := c.Servers.PublicCurrent(c.discoMan) if err != nil { return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved") } @@ -579,9 +570,7 @@ func (c *Client) Cleanup(ck *cookie.Cookie) error { if err != nil { return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection") } - disco, release := c.discoMan.Discovery(true) - defer release() - auth, err := srv.ServerWithCallbacks(ck.Context(), disco, tok, true) + auth, err := srv.ServerWithCallbacks(ck.Context(), c.discoMan, tok, true) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection") } @@ -632,9 +621,7 @@ func (c *Client) RenewSession(ck *cookie.Cookie) error { previousState := c.FSM.Current // getting a server with no tokens means re-authorize - disco, release := c.discoMan.Discovery(true) - defer release() - _, err = srv.ServerWithCallbacks(ck.Context(), disco, nil, false) + _, err = srv.ServerWithCallbacks(ck.Context(), c.discoMan, nil, false) if err != nil { c.FSM.GoTransition(previousState) //nolint:errcheck return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session") diff --git a/client/discovery.go b/client/discovery.go index 2816d10..6cc76f6 100644 --- a/client/discovery.go +++ b/client/discovery.go @@ -4,11 +4,8 @@ import ( "context" "sort" "strings" - "sync" "github.com/eduvpn/eduvpn-common/i18nerr" - "github.com/eduvpn/eduvpn-common/internal/discovery" - "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" ) @@ -138,79 +135,3 @@ func (c *Client) DiscoveryStartup(cb func()) error { c.discoMan.Startup(context.Background(), fcb) return nil } - -type DiscoManager struct { - disco *discovery.Discovery - - cancel context.CancelFunc - mu sync.RWMutex - wait sync.WaitGroup -} - -func (m *DiscoManager) lock(write bool) { - log.Logger.Debugf("Locking write: %v", write) - if write { - m.mu.Lock() - return - } - m.mu.RLock() -} - -func (m *DiscoManager) unlock(write bool) { - log.Logger.Debugf("Unlocking write: %v", write) - if write { - m.mu.Unlock() - return - } - m.mu.RUnlock() -} - -func (m *DiscoManager) Discovery(write bool) (*discovery.Discovery, func()) { - log.Logger.Debugf("Requesting discovery write: %v", write) - if write { - m.wait.Wait() - } - m.lock(write) - return m.disco, func() { - m.unlock(write) - } -} - -func (m *DiscoManager) Cancel() { - if m.cancel != nil { - m.cancel() - } - m.wait.Wait() -} - -func (m *DiscoManager) Startup(ctx context.Context, cb func()) { - ctx, cancel := context.WithCancel(ctx) - m.cancel = cancel - m.wait.Add(1) - go func() { - m.lock(false) - discoCopy, err := m.disco.Copy() - if err != nil { - log.Logger.Warningf("internal error, failed to clone discovery, %v", err) - return - } - m.unlock(false) - // we already log the warning - discoCopy.Servers(ctx) //nolint:errcheck - - m.lock(true) - m.disco.UpdateServers(discoCopy) - m.unlock(true) - m.wait.Done() - - select { - case <-ctx.Done(): - return - default: - if cb == nil { - return - } - cb() - } - }() -} diff --git a/internal/discovery/manager.go b/internal/discovery/manager.go new file mode 100644 index 0000000..b84e770 --- /dev/null +++ b/internal/discovery/manager.go @@ -0,0 +1,88 @@ +package discovery + +import ( + "context" + "sync" + + "github.com/eduvpn/eduvpn-common/internal/log" +) + +type Manager struct { + disco *Discovery + + cancel context.CancelFunc + mu sync.RWMutex + wait sync.WaitGroup +} + +func NewManager(disco *Discovery) *Manager { + return &Manager{disco: disco} +} + +func (m *Manager) lock(write bool) { + log.Logger.Debugf("Locking write: %v", write) + if write { + m.mu.Lock() + return + } + m.mu.RLock() +} + +func (m *Manager) unlock(write bool) { + log.Logger.Debugf("Unlocking write: %v", write) + if write { + m.mu.Unlock() + return + } + m.mu.RUnlock() +} + +func (m *Manager) Discovery(write bool) (*Discovery, func()) { + log.Logger.Debugf("Requesting discovery write: %v", write) + if write { + m.wait.Wait() + } + m.lock(write) + return m.disco, func() { + m.unlock(write) + } +} + +func (m *Manager) Cancel() { + if m.cancel != nil { + m.cancel() + } + m.wait.Wait() +} + +func (m *Manager) Startup(ctx context.Context, cb func()) { + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wait.Add(1) + go func() { + m.lock(false) + discoCopy, err := m.disco.Copy() + if err != nil { + log.Logger.Warningf("internal error, failed to clone discovery, %v", err) + return + } + m.unlock(false) + // we already log the warning + discoCopy.Servers(ctx) //nolint:errcheck + + m.lock(true) + m.disco.UpdateServers(discoCopy) + m.unlock(true) + m.wait.Done() + + select { + case <-ctx.Done(): + return + default: + if cb == nil { + return + } + cb() + } + }() +} diff --git a/internal/server/institute.go b/internal/server/institute.go index caae004..4ddf939 100644 --- a/internal/server/institute.go +++ b/internal/server/institute.go @@ -17,12 +17,15 @@ import ( // `disco` are the discovery servers // `id` is the identifier for the server, the base url // `ot` specifies specifies the start time OAuth was already triggered -func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, ot *int64) error { +func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, id string, ot *int64) error { // This is basically done to double check if the server is part of the institute access section of disco + disco, release := discom.Discovery(false) dsrv, err := disco.ServerByURL(id, "institute_access") if err != nil { + release() return err } + release() sd := api.ServerData{ ID: dsrv.BaseURL, @@ -67,12 +70,15 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, // `disco` are the discovery servers // `tok` are the tokens such that we do not have to trigger auth // `disableAuth` is true when auth should never be triggered -func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { +func (s *Servers) GetInstitute(ctx context.Context, id string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) { + disco, release := discom.Discovery(false) // This is basically done to double check if the server is part of the institute access section of disco dsrv, err := disco.ServerByURL(id, "institute_access") if err != nil { + release() return nil, err } + release() // Get the server from the config _, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess) diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 990ceb3..9b4b873 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -19,14 +19,17 @@ import ( // `disco` are the discovery servers // `orgID` is the organiztaion ID // `ot` specifies specifies the start time OAuth was already triggered -func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, ot *int64) error { +func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgID string, ot *int64) error { if s.config.HasSecureInternet() { return errors.New("a secure internet server already exists") } + disco, release := discom.Discovery(false) dorg, dsrv, err := disco.SecureHomeArgs(orgID) if err != nil { + release() return err } + release() sd := api.ServerData{ ID: dorg.OrgID, @@ -34,11 +37,13 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org BaseWK: dsrv.BaseURL, BaseAuthWK: dsrv.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { + newd, release := discom.Discovery(true) + defer release() // the only thing we can do is log warn // this is already done in the functions - disco.Servers(ctx) //nolint:errcheck - disco.Organizations(ctx) //nolint:errcheck - updorg, updsrv, err := disco.SecureHomeArgs(orgID) + newd.Servers(ctx) //nolint:errcheck + newd.Organizations(ctx) //nolint:errcheck + updorg, updsrv, err := newd.SecureHomeArgs(orgID) if err != nil { return "", err } @@ -84,21 +89,25 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org // `disco` are the discovery servers // `tok` are the tokens such that the server can be found without triggering auth // `disableAuth` is set to true when authorization should not be triggered -func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { +func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) { srv, err := s.config.GetServer(orgID, server.TypeSecureInternet) if err != nil { return nil, err } + disco, release := discom.Discovery(false) dorg, dhome, err := disco.SecureHomeArgs(orgID) if err != nil { + release() return nil, err } dloc, err := disco.ServerByCountryCode(srv.CountryCode) if err != nil { + release() return nil, err } + release() sd := api.ServerData{ ID: dorg.OrgID, @@ -106,13 +115,15 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery. BaseWK: dloc.BaseURL, BaseAuthWK: dhome.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { + newd, release := discom.Discovery(true) + defer release() // the only thing we can do is log warn // this is already done in the functions - disco.MarkServersExpired() - disco.Servers(ctx) //nolint:errcheck - disco.MarkOrganizationsExpired() - disco.Organizations(ctx) //nolint:errcheck - updorg, updsrv, err := disco.SecureHomeArgs(orgID) + newd.MarkServersExpired() + newd.Servers(ctx) //nolint:errcheck + newd.MarkOrganizationsExpired() + newd.Organizations(ctx) //nolint:errcheck + updorg, updsrv, err := newd.SecureHomeArgs(orgID) if err != nil { return "", err } diff --git a/internal/server/servers.go b/internal/server/servers.go index 8285801..d06c37d 100644 --- a/internal/server/servers.go +++ b/internal/server/servers.go @@ -54,12 +54,12 @@ type CurrentServer struct { } // ServerWithCallbacks gets the current server as a server struct and triggers callbacks as needed -func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { +func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, discom *discovery.Manager, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { switch cs.Key.T { case srvtypes.TypeInstituteAccess: - return cs.srvs.GetInstitute(ctx, cs.Key.ID, disco, tokens, disableAuth) + return cs.srvs.GetInstitute(ctx, cs.Key.ID, discom, tokens, disableAuth) case srvtypes.TypeSecureInternet: - return cs.srvs.GetSecure(ctx, cs.Key.ID, disco, tokens, disableAuth) + return cs.srvs.GetSecure(ctx, cs.Key.ID, discom, tokens, disableAuth) case srvtypes.TypeCustom: return cs.srvs.GetCustom(ctx, cs.Key.ID, tokens, disableAuth) default: @@ -89,7 +89,9 @@ func (s *Servers) CurrentServer() (*CurrentServer, error) { } // PublicCurrent gets the current server into a type that we can return to the client -func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) { +func (s *Servers) PublicCurrent(discom *discovery.Manager) (*srvtypes.Current, error) { + disco, release := discom.Discovery(false) + defer release() return s.config.PublicCurrent(disco) } |
