From 132782f44603dfdc3b1d875d632f786109ee09a2 Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Fri, 29 Aug 2025 14:36:00 +0200 Subject: Discovery: Remove manager and DiscoveryStartup --- internal/discovery/discovery.go | 59 +++++++++++++++---------- internal/discovery/manager.go | 92 --------------------------------------- internal/server/institute.go | 10 +---- internal/server/secureinternet.go | 31 +++++-------- internal/server/servers.go | 10 ++--- 5 files changed, 52 insertions(+), 150 deletions(-) delete mode 100644 internal/discovery/manager.go (limited to 'internal') diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 80a69eb..512756b 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -8,6 +8,7 @@ import ( "fmt" "log/slog" "net/http" + "sync" "time" httpw "codeberg.org/eduVPN/eduvpn-common/internal/http" @@ -83,6 +84,8 @@ func (s *Server) Score(search string) int { // Discovery is the main structure used for this package. type Discovery struct { + // mu is the read write mutex that protects the struct from concurrent access + mu sync.RWMutex // The httpClient for sending HTTP requests httpClient *httpw.Client @@ -174,23 +177,27 @@ func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousV // MarkOrganizationsExpired marks the organizations as expired func (discovery *Discovery) MarkOrganizationsExpired() { + discovery.mu.Lock() + defer discovery.mu.Unlock() // Re-initialize the timestamp to zero discovery.OrganizationList.Timestamp = time.Time{} } // MarkServersExpired marks the servers as expired func (discovery *Discovery) MarkServersExpired() { + discovery.mu.Lock() + defer discovery.mu.Unlock() // Re-initialize the timestamp to zero discovery.ServerList.Timestamp = time.Time{} } -// DetermineOrganizationsUpdate returns a boolean indicating whether or not the discovery organizations should be updated +// determineOrganizationsUpdate returns a boolean indicating whether or not the discovery organizations should be updated // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md // - [IMPLEMENTED] on "first launch" when offering the search for "Institute Access" and "Organizations"; // - [IMPLEMENTED in client/client.go and here] when the user tries to add new server AND the user did NOT yet choose an organization before; Implemented in Register() // - [IMPLEMENTED in client/client.go] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation. // - [IMPLEMENTED here] NOTE: when the org_id that the user chose previously is no longer available in organization_list.json the application should ask the user to choose their organization (again). This can occur for example when the organization replaced their identity provider, uses a different domain after rebranding or simply ceased to exist. -func (discovery *Discovery) DetermineOrganizationsUpdate() bool { +func (discovery *Discovery) determineOrganizationsUpdate() bool { if discovery.OrganizationList.Timestamp.IsZero() { return true } @@ -202,6 +209,8 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool { // SecureLocationList returns a slice of all the available locations. func (discovery *Discovery) SecureLocationList() []string { + discovery.mu.RLock() + defer discovery.mu.RUnlock() var loc []string for _, srv := range discovery.ServerList.List { if srv.Type == "secure_internet" { @@ -217,6 +226,8 @@ func (discovery *Discovery) ServerByURL( baseURL string, srvType string, ) (*Server, error) { + discovery.mu.RLock() + defer discovery.mu.RUnlock() for _, currentServer := range discovery.ServerList.List { if currentServer.BaseURL == baseURL && currentServer.Type == srvType { return ¤tServer, nil @@ -237,6 +248,8 @@ func (cnf *ErrCountryNotFound) Error() string { // ServerByCountryCode returns the discovery server by the country code // An error is returned if and only if nil is returned for the server. func (discovery *Discovery) ServerByCountryCode(countryCode string) (*Server, error) { + discovery.mu.RLock() + defer discovery.mu.RUnlock() for _, srv := range discovery.ServerList.List { if srv.CountryCode == countryCode && srv.Type == "secure_internet" { return &srv, nil @@ -261,8 +274,10 @@ func (discovery *Discovery) orgByID(orgID string) (*Organization, error) { // - The secure internet server itself // An error is returned if and only if nil is returned for the organization. func (discovery *Discovery) SecureHomeArgs(orgID string) (*Organization, *Server, error) { + discovery.mu.RLock() org, err := discovery.orgByID(orgID) if err != nil { + discovery.mu.RUnlock() discovery.MarkOrganizationsExpired() return nil, nil, err } @@ -270,17 +285,19 @@ func (discovery *Discovery) SecureHomeArgs(orgID string) (*Organization, *Server // Get a server with the base url srv, err := discovery.ServerByURL(org.SecureInternetHome, "secure_internet") if err != nil { + discovery.mu.RUnlock() discovery.MarkOrganizationsExpired() return nil, nil, err } + discovery.mu.RUnlock() return org, srv, nil } -// DetermineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server +// determineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md // - [Implemented] The application MUST always fetch the server_list.json at application start. // - The application MAY refresh the server_list.json periodically, e.g. once every hour. -func (discovery *Discovery) DetermineServersUpdate() bool { +func (discovery *Discovery) determineServersUpdate() bool { // No servers, we should update if discovery.ServerList.Timestamp.IsZero() { return true @@ -305,9 +322,14 @@ func (discovery *Discovery) cachedOrgs() *Organizations { // If there was an error, a cached copy is returned if available. // cache is set to true if there should be no network call done func (discovery *Discovery) Organizations(ctx context.Context, cache bool) (*Organizations, bool, error) { - if cache || !discovery.DetermineOrganizationsUpdate() { + discovery.mu.RLock() + if cache || !discovery.determineOrganizationsUpdate() { + discovery.mu.RUnlock() return discovery.cachedOrgs(), false, nil } + discovery.mu.RUnlock() + discovery.mu.Lock() + defer discovery.mu.Unlock() file := "organization_list.json" var jsonDecode Organizations update, err := discovery.file(ctx, file, discovery.OrganizationList.Version, discovery.OrganizationList.UpdateHeader, &jsonDecode) @@ -349,9 +371,14 @@ func (discovery *Discovery) cachedServers() *Servers { // If there was an error, a cached copy is returned if available. // cache is set to true if there should be no network call done func (discovery *Discovery) Servers(ctx context.Context, cache bool) (*Servers, bool, error) { - if cache || !discovery.DetermineServersUpdate() { + discovery.mu.RLock() + if cache || !discovery.determineServersUpdate() { + discovery.mu.RUnlock() return discovery.cachedServers(), false, nil } + discovery.mu.RUnlock() + discovery.mu.Lock() + defer discovery.mu.Unlock() file := "server_list.json" var jsonDecode Servers update, err := discovery.file(ctx, file, discovery.ServerList.Version, discovery.ServerList.UpdateHeader, &jsonDecode) @@ -397,29 +424,15 @@ func (discovery *Discovery) UpdateOrganizations(other Organizations) { } } -// Copy creates a deep-copy for the discovery struct -// It does this by marshalling and unmarshalling it as JSON -func (discovery *Discovery) Copy() (Discovery, error) { - var dest Discovery - b, err := json.Marshal(discovery) - if err != nil { - return dest, err - } - - err = json.Unmarshal(b, &dest) - if err != nil { - return dest, err - } - - return dest, nil -} - // Fill makes sure that the cache is filled with the embedded discovery func (discovery *Discovery) Fill() error { if !HasCache { return nil } + discovery.mu.Lock() + defer discovery.mu.Unlock() + var err error var es Servers err = errors.Join(err, json.Unmarshal(eServers, &es)) diff --git a/internal/discovery/manager.go b/internal/discovery/manager.go deleted file mode 100644 index 4fb4f8e..0000000 --- a/internal/discovery/manager.go +++ /dev/null @@ -1,92 +0,0 @@ -package discovery - -import ( - "context" - "log/slog" - "sync" -) - -// Manager is the discovery struct that is cached -// with some bookkeeping to avoid race conditions -type Manager struct { - disco *Discovery - - cancel context.CancelFunc - mu sync.RWMutex - wait sync.WaitGroup -} - -// NewManager creates a new Discovery manager -func NewManager(disco *Discovery) *Manager { - return &Manager{disco: disco} -} - -func (m *Manager) lock(write bool) { - if write { - m.mu.Lock() - return - } - m.mu.RLock() -} - -func (m *Manager) unlock(write bool) { - if write { - m.mu.Unlock() - return - } - m.mu.RUnlock() -} - -// Discovery gets the cached discovery -// `write` is true if discovery will be written to -func (m *Manager) Discovery(write bool) (*Discovery, func()) { - if write { - m.wait.Wait() - } - m.lock(write) - return m.disco, func() { - m.unlock(write) - } -} - -// Cancel aborts the running discovery startup process -func (m *Manager) Cancel() { - if m.cancel != nil { - m.cancel() - } - m.wait.Wait() -} - -// Startup handles the discovery process in the background -// It's called Startup because it's called when the lib is initialised -func (m *Manager) Startup(ctx context.Context, cb func()) { - ctx, cancel := context.WithCancel(ctx) - m.cancel = cancel - m.wait.Add(1) - go func() { - m.lock(false) - discoCopy, err := m.disco.Copy() - if err != nil { - slog.Warn("failed to clone discovery", "error", err) - return - } - m.unlock(false) - // we already log the warning - discoCopy.Servers(ctx, false) //nolint:errcheck - - m.lock(true) - m.disco.UpdateServers(discoCopy.ServerList) - m.unlock(true) - m.wait.Done() - - select { - case <-ctx.Done(): - return - default: - if cb == nil { - return - } - cb() - } - }() -} diff --git a/internal/server/institute.go b/internal/server/institute.go index 348e895..0a7ef3a 100644 --- a/internal/server/institute.go +++ b/internal/server/institute.go @@ -17,15 +17,12 @@ import ( // `disco` are the discovery servers // `id` is the identifier for the server, the base url // `ot` specifies specifies the start time OAuth was already triggered -func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, id string, ot *int64) error { +func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, ot *int64) error { // This is basically done to double check if the server is part of the institute access section of disco - disco, release := discom.Discovery(false) dsrv, err := disco.ServerByURL(id, "institute_access") if err != nil { - release() return err } - release() sd := api.ServerData{ ID: dsrv.BaseURL, @@ -70,15 +67,12 @@ func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, i // `disco` are the discovery servers // `tok` are the tokens such that we do not have to trigger auth // `disableAuth` is true when auth should never be triggered -func (s *Servers) GetInstitute(ctx context.Context, id string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) { - disco, release := discom.Discovery(false) +func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { // This is basically done to double check if the server is part of the institute access section of disco dsrv, err := disco.ServerByURL(id, "institute_access") if err != nil { - release() return nil, err } - release() // Get the server from the config _, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess) diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 69b1e97..18a5e78 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -45,17 +45,14 @@ func ReplaceWAYF(template string, authURL string, orgID string) string { // `disco` are the discovery servers // `orgID` is the organiztaion ID // `ot` specifies specifies the start time OAuth was already triggered -func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgID string, ot *int64) error { +func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, ot *int64) error { if s.config.HasSecureInternet() { return errors.New("a secure internet server already exists") } - disco, release := discom.Discovery(false) dorg, dsrv, err := disco.SecureHomeArgs(orgID) if err != nil { - release() return err } - release() sd := api.ServerData{ ID: dorg.OrgID, @@ -63,13 +60,11 @@ func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgI BaseWK: dsrv.BaseURL, BaseAuthWK: dsrv.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { - newd, release := discom.Discovery(true) - defer release() // the only thing we can do is log warn // this is already done in the functions - newd.Servers(ctx, false) //nolint:errcheck - newd.Organizations(ctx, false) //nolint:errcheck - updorg, updsrv, err := newd.SecureHomeArgs(orgID) + disco.Servers(ctx, false) //nolint:errcheck + disco.Organizations(ctx, false) //nolint:errcheck + updorg, updsrv, err := disco.SecureHomeArgs(orgID) if err != nil { return "", err } @@ -115,25 +110,21 @@ func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgI // `disco` are the discovery servers // `tok` are the tokens such that the server can be found without triggering auth // `disableAuth` is set to true when authorization should not be triggered -func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) { +func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) { srv, err := s.config.GetServer(orgID, server.TypeSecureInternet) if err != nil { return nil, err } - disco, release := discom.Discovery(false) dorg, dhome, err := disco.SecureHomeArgs(orgID) if err != nil { - release() return nil, err } dloc, err := disco.ServerByCountryCode(srv.CountryCode) if err != nil { - release() return nil, err } - release() sd := api.ServerData{ ID: dorg.OrgID, @@ -141,15 +132,13 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery BaseWK: dloc.BaseURL, BaseAuthWK: dhome.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { - newd, release := discom.Discovery(true) - defer release() // the only thing we can do is log warn // this is already done in the functions - newd.MarkServersExpired() - newd.Servers(ctx, false) //nolint:errcheck - newd.MarkOrganizationsExpired() - newd.Organizations(ctx, false) //nolint:errcheck - updorg, updsrv, err := newd.SecureHomeArgs(orgID) + disco.MarkServersExpired() + disco.Servers(ctx, false) //nolint:errcheck + disco.MarkOrganizationsExpired() + disco.Organizations(ctx, false) //nolint:errcheck + updorg, updsrv, err := disco.SecureHomeArgs(orgID) if err != nil { return "", err } diff --git a/internal/server/servers.go b/internal/server/servers.go index d1aed97..60b9277 100644 --- a/internal/server/servers.go +++ b/internal/server/servers.go @@ -54,12 +54,12 @@ type CurrentServer struct { } // ServerWithCallbacks gets the current server as a server struct and triggers callbacks as needed -func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, discom *discovery.Manager, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { +func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) { switch cs.Key.T { case srvtypes.TypeInstituteAccess: - return cs.srvs.GetInstitute(ctx, cs.Key.ID, discom, tokens, disableAuth) + return cs.srvs.GetInstitute(ctx, cs.Key.ID, disco, tokens, disableAuth) case srvtypes.TypeSecureInternet: - return cs.srvs.GetSecure(ctx, cs.Key.ID, discom, tokens, disableAuth) + return cs.srvs.GetSecure(ctx, cs.Key.ID, disco, tokens, disableAuth) case srvtypes.TypeCustom: return cs.srvs.GetCustom(ctx, cs.Key.ID, tokens, disableAuth) default: @@ -89,9 +89,7 @@ func (s *Servers) CurrentServer() (*CurrentServer, error) { } // PublicCurrent gets the current server into a type that we can return to the client -func (s *Servers) PublicCurrent(discom *discovery.Manager) (*srvtypes.Current, error) { - disco, release := discom.Discovery(false) - defer release() +func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) { return s.config.PublicCurrent(disco) } -- cgit v1.2.3