diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2024-07-17 11:20:59 +0200 |
|---|---|---|
| committer | Jeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com> | 2024-07-17 14:00:03 +0000 |
| commit | 7276108171b1c6af417ce5ae86ef0221280932c1 (patch) | |
| tree | 3decc148de89b9ae3b3b6d04de35fea71ffb4c3b /internal | |
| parent | 5362dd10bf13f8087a13a4ab5441efa04755fcc7 (diff) | |
Client + Server: Pass discovery manager and lock when needed
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/discovery/manager.go | 88 | ||||
| -rw-r--r-- | internal/server/institute.go | 10 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 31 | ||||
| -rw-r--r-- | internal/server/servers.go | 10 |
4 files changed, 123 insertions, 16 deletions
diff --git a/internal/discovery/manager.go b/internal/discovery/manager.go new file mode 100644 index 0000000..b84e770 --- /dev/null +++ b/internal/discovery/manager.go @@ -0,0 +1,88 @@ +package discovery + +import ( + "context" + "sync" + + "github.com/eduvpn/eduvpn-common/internal/log" +) + +type Manager struct { + disco *Discovery + + cancel context.CancelFunc + mu sync.RWMutex + wait sync.WaitGroup +} + +func NewManager(disco *Discovery) *Manager { + return &Manager{disco: disco} +} + +func (m *Manager) lock(write bool) { + log.Logger.Debugf("Locking write: %v", write) + if write { + m.mu.Lock() + return + } + m.mu.RLock() +} + +func (m *Manager) unlock(write bool) { + log.Logger.Debugf("Unlocking write: %v", write) + if write { + m.mu.Unlock() + return + } + m.mu.RUnlock() +} + +func (m *Manager) Discovery(write bool) (*Discovery, func()) { + log.Logger.Debugf("Requesting discovery write: %v", write) + if write { + m.wait.Wait() + } + m.lock(write) + return m.disco, func() { + m.unlock(write) + } +} + +func (m *Manager) Cancel() { + if m.cancel != nil { + m.cancel() + } + m.wait.Wait() +} + +func (m *Manager) Startup(ctx context.Context, cb func()) { + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wait.Add(1) + go func() { + m.lock(false) + discoCopy, err := m.disco.Copy() + if err != nil { + log.Logger.Warningf("internal error, failed to clone discovery, %v", err) + return + } + m.unlock(false) + // we already log the warning + discoCopy.Servers(ctx) //nolint:errcheck + + m.lock(true) + m.disco.UpdateServers(discoCopy) + m.unlock(true) + m.wait.Done() + + select { + case <-ctx.Done(): + return + default: + if cb == nil { + return + } + cb() + } + }() +} diff --git a/internal/server/institute.go b/internal/server/institute.go index caae004..4ddf939 100644 --- a/internal/server/institute.go +++ b/internal/server/institute.go @@ -17,12 +17,15 @@ import ( // `disco` are the discovery servers // `id` is the identifier for the server, the base url // `ot` specifies specifies the start time OAuth was already triggered -func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, ot *int64) error { +func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, id string, ot *int64) error { // This is basically done to double check if the server is part of the institute access section of disco + disco, release := discom.Discovery(false) dsrv, err := disco.ServerByURL(id, "institute_access") if err != nil { + release() return err } + release() sd := api.ServerData{ ID: dsrv.BaseURL, @@ -67,12 +70,15 @@ func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, // `disco` are the discovery servers // `tok` are the tokens such that we do not have to trigger auth // `disableAuth` is true when auth should never be triggered -func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { +func (s *Servers) GetInstitute(ctx context.Context, id string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) { + disco, release := discom.Discovery(false) // This is basically done to double check if the server is part of the institute access section of disco dsrv, err := disco.ServerByURL(id, "institute_access") if err != nil { + release() return nil, err } + release() // Get the server from the config _, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess) diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 990ceb3..9b4b873 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -19,14 +19,17 @@ import ( // `disco` are the discovery servers // `orgID` is the organiztaion ID // `ot` specifies specifies the start time OAuth was already triggered -func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, ot *int64) error { +func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgID string, ot *int64) error { if s.config.HasSecureInternet() { return errors.New("a secure internet server already exists") } + disco, release := discom.Discovery(false) dorg, dsrv, err := disco.SecureHomeArgs(orgID) if err != nil { + release() return err } + release() sd := api.ServerData{ ID: dorg.OrgID, @@ -34,11 +37,13 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org BaseWK: dsrv.BaseURL, BaseAuthWK: dsrv.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { + newd, release := discom.Discovery(true) + defer release() // the only thing we can do is log warn // this is already done in the functions - disco.Servers(ctx) //nolint:errcheck - disco.Organizations(ctx) //nolint:errcheck - updorg, updsrv, err := disco.SecureHomeArgs(orgID) + newd.Servers(ctx) //nolint:errcheck + newd.Organizations(ctx) //nolint:errcheck + updorg, updsrv, err := newd.SecureHomeArgs(orgID) if err != nil { return "", err } @@ -84,21 +89,25 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org // `disco` are the discovery servers // `tok` are the tokens such that the server can be found without triggering auth // `disableAuth` is set to true when authorization should not be triggered -func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { +func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) { srv, err := s.config.GetServer(orgID, server.TypeSecureInternet) if err != nil { return nil, err } + disco, release := discom.Discovery(false) dorg, dhome, err := disco.SecureHomeArgs(orgID) if err != nil { + release() return nil, err } dloc, err := disco.ServerByCountryCode(srv.CountryCode) if err != nil { + release() return nil, err } + release() sd := api.ServerData{ ID: dorg.OrgID, @@ -106,13 +115,15 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery. BaseWK: dloc.BaseURL, BaseAuthWK: dhome.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { + newd, release := discom.Discovery(true) + defer release() // the only thing we can do is log warn // this is already done in the functions - disco.MarkServersExpired() - disco.Servers(ctx) //nolint:errcheck - disco.MarkOrganizationsExpired() - disco.Organizations(ctx) //nolint:errcheck - updorg, updsrv, err := disco.SecureHomeArgs(orgID) + newd.MarkServersExpired() + newd.Servers(ctx) //nolint:errcheck + newd.MarkOrganizationsExpired() + newd.Organizations(ctx) //nolint:errcheck + updorg, updsrv, err := newd.SecureHomeArgs(orgID) if err != nil { return "", err } diff --git a/internal/server/servers.go b/internal/server/servers.go index 8285801..d06c37d 100644 --- a/internal/server/servers.go +++ b/internal/server/servers.go @@ -54,12 +54,12 @@ type CurrentServer struct { } // ServerWithCallbacks gets the current server as a server struct and triggers callbacks as needed -func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { +func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, discom *discovery.Manager, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { switch cs.Key.T { case srvtypes.TypeInstituteAccess: - return cs.srvs.GetInstitute(ctx, cs.Key.ID, disco, tokens, disableAuth) + return cs.srvs.GetInstitute(ctx, cs.Key.ID, discom, tokens, disableAuth) case srvtypes.TypeSecureInternet: - return cs.srvs.GetSecure(ctx, cs.Key.ID, disco, tokens, disableAuth) + return cs.srvs.GetSecure(ctx, cs.Key.ID, discom, tokens, disableAuth) case srvtypes.TypeCustom: return cs.srvs.GetCustom(ctx, cs.Key.ID, tokens, disableAuth) default: @@ -89,7 +89,9 @@ func (s *Servers) CurrentServer() (*CurrentServer, error) { } // PublicCurrent gets the current server into a type that we can return to the client -func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) { +func (s *Servers) PublicCurrent(discom *discovery.Manager) (*srvtypes.Current, error) { + disco, release := discom.Discovery(false) + defer release() return s.config.PublicCurrent(disco) } |
