summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorJeroen Wijenbergh <jeroen.wijenbergh@geant.org>2025-08-29 14:36:00 +0200
committerJeroen Wijenbergh <jeroen.wijenbergh@geant.org>2025-08-29 14:40:25 +0200
commit132782f44603dfdc3b1d875d632f786109ee09a2 (patch)
tree71cd1cefee78636167c560c80c8bc79c8982fe26 /internal
parent4a23134e1e5d70a9c8c5857790dbf27585ca3b1f (diff)
Discovery: Remove manager and DiscoveryStartup
Diffstat (limited to 'internal')
-rw-r--r--internal/discovery/discovery.go59
-rw-r--r--internal/discovery/manager.go92
-rw-r--r--internal/server/institute.go10
-rw-r--r--internal/server/secureinternet.go31
-rw-r--r--internal/server/servers.go10
5 files changed, 52 insertions, 150 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index 80a69eb..512756b 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -8,6 +8,7 @@ import (
"fmt"
"log/slog"
"net/http"
+ "sync"
"time"
httpw "codeberg.org/eduVPN/eduvpn-common/internal/http"
@@ -83,6 +84,8 @@ func (s *Server) Score(search string) int {
// Discovery is the main structure used for this package.
type Discovery struct {
+ // mu is the read write mutex that protects the struct from concurrent access
+ mu sync.RWMutex
// The httpClient for sending HTTP requests
httpClient *httpw.Client
@@ -174,23 +177,27 @@ func (discovery *Discovery) file(ctx context.Context, jsonFile string, previousV
// MarkOrganizationsExpired marks the organizations as expired
func (discovery *Discovery) MarkOrganizationsExpired() {
+ discovery.mu.Lock()
+ defer discovery.mu.Unlock()
// Re-initialize the timestamp to zero
discovery.OrganizationList.Timestamp = time.Time{}
}
// MarkServersExpired marks the servers as expired
func (discovery *Discovery) MarkServersExpired() {
+ discovery.mu.Lock()
+ defer discovery.mu.Unlock()
// Re-initialize the timestamp to zero
discovery.ServerList.Timestamp = time.Time{}
}
-// DetermineOrganizationsUpdate returns a boolean indicating whether or not the discovery organizations should be updated
+// determineOrganizationsUpdate returns a boolean indicating whether or not the discovery organizations should be updated
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
// - [IMPLEMENTED] on "first launch" when offering the search for "Institute Access" and "Organizations";
// - [IMPLEMENTED in client/client.go and here] when the user tries to add new server AND the user did NOT yet choose an organization before; Implemented in Register()
// - [IMPLEMENTED in client/client.go] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation.
// - [IMPLEMENTED here] NOTE: when the org_id that the user chose previously is no longer available in organization_list.json the application should ask the user to choose their organization (again). This can occur for example when the organization replaced their identity provider, uses a different domain after rebranding or simply ceased to exist.
-func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
+func (discovery *Discovery) determineOrganizationsUpdate() bool {
if discovery.OrganizationList.Timestamp.IsZero() {
return true
}
@@ -202,6 +209,8 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
// SecureLocationList returns a slice of all the available locations.
func (discovery *Discovery) SecureLocationList() []string {
+ discovery.mu.RLock()
+ defer discovery.mu.RUnlock()
var loc []string
for _, srv := range discovery.ServerList.List {
if srv.Type == "secure_internet" {
@@ -217,6 +226,8 @@ func (discovery *Discovery) ServerByURL(
baseURL string,
srvType string,
) (*Server, error) {
+ discovery.mu.RLock()
+ defer discovery.mu.RUnlock()
for _, currentServer := range discovery.ServerList.List {
if currentServer.BaseURL == baseURL && currentServer.Type == srvType {
return &currentServer, nil
@@ -237,6 +248,8 @@ func (cnf *ErrCountryNotFound) Error() string {
// ServerByCountryCode returns the discovery server by the country code
// An error is returned if and only if nil is returned for the server.
func (discovery *Discovery) ServerByCountryCode(countryCode string) (*Server, error) {
+ discovery.mu.RLock()
+ defer discovery.mu.RUnlock()
for _, srv := range discovery.ServerList.List {
if srv.CountryCode == countryCode && srv.Type == "secure_internet" {
return &srv, nil
@@ -261,8 +274,10 @@ func (discovery *Discovery) orgByID(orgID string) (*Organization, error) {
// - The secure internet server itself
// An error is returned if and only if nil is returned for the organization.
func (discovery *Discovery) SecureHomeArgs(orgID string) (*Organization, *Server, error) {
+ discovery.mu.RLock()
org, err := discovery.orgByID(orgID)
if err != nil {
+ discovery.mu.RUnlock()
discovery.MarkOrganizationsExpired()
return nil, nil, err
}
@@ -270,17 +285,19 @@ func (discovery *Discovery) SecureHomeArgs(orgID string) (*Organization, *Server
// Get a server with the base url
srv, err := discovery.ServerByURL(org.SecureInternetHome, "secure_internet")
if err != nil {
+ discovery.mu.RUnlock()
discovery.MarkOrganizationsExpired()
return nil, nil, err
}
+ discovery.mu.RUnlock()
return org, srv, nil
}
-// DetermineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server
+// determineServersUpdate returns whether or not the discovery servers should be updated by contacting the discovery server
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
// - [Implemented] The application MUST always fetch the server_list.json at application start.
// - The application MAY refresh the server_list.json periodically, e.g. once every hour.
-func (discovery *Discovery) DetermineServersUpdate() bool {
+func (discovery *Discovery) determineServersUpdate() bool {
// No servers, we should update
if discovery.ServerList.Timestamp.IsZero() {
return true
@@ -305,9 +322,14 @@ func (discovery *Discovery) cachedOrgs() *Organizations {
// If there was an error, a cached copy is returned if available.
// cache is set to true if there should be no network call done
func (discovery *Discovery) Organizations(ctx context.Context, cache bool) (*Organizations, bool, error) {
- if cache || !discovery.DetermineOrganizationsUpdate() {
+ discovery.mu.RLock()
+ if cache || !discovery.determineOrganizationsUpdate() {
+ discovery.mu.RUnlock()
return discovery.cachedOrgs(), false, nil
}
+ discovery.mu.RUnlock()
+ discovery.mu.Lock()
+ defer discovery.mu.Unlock()
file := "organization_list.json"
var jsonDecode Organizations
update, err := discovery.file(ctx, file, discovery.OrganizationList.Version, discovery.OrganizationList.UpdateHeader, &jsonDecode)
@@ -349,9 +371,14 @@ func (discovery *Discovery) cachedServers() *Servers {
// If there was an error, a cached copy is returned if available.
// cache is set to true if there should be no network call done
func (discovery *Discovery) Servers(ctx context.Context, cache bool) (*Servers, bool, error) {
- if cache || !discovery.DetermineServersUpdate() {
+ discovery.mu.RLock()
+ if cache || !discovery.determineServersUpdate() {
+ discovery.mu.RUnlock()
return discovery.cachedServers(), false, nil
}
+ discovery.mu.RUnlock()
+ discovery.mu.Lock()
+ defer discovery.mu.Unlock()
file := "server_list.json"
var jsonDecode Servers
update, err := discovery.file(ctx, file, discovery.ServerList.Version, discovery.ServerList.UpdateHeader, &jsonDecode)
@@ -397,29 +424,15 @@ func (discovery *Discovery) UpdateOrganizations(other Organizations) {
}
}
-// Copy creates a deep-copy for the discovery struct
-// It does this by marshalling and unmarshalling it as JSON
-func (discovery *Discovery) Copy() (Discovery, error) {
- var dest Discovery
- b, err := json.Marshal(discovery)
- if err != nil {
- return dest, err
- }
-
- err = json.Unmarshal(b, &dest)
- if err != nil {
- return dest, err
- }
-
- return dest, nil
-}
-
// Fill makes sure that the cache is filled with the embedded discovery
func (discovery *Discovery) Fill() error {
if !HasCache {
return nil
}
+ discovery.mu.Lock()
+ defer discovery.mu.Unlock()
+
var err error
var es Servers
err = errors.Join(err, json.Unmarshal(eServers, &es))
diff --git a/internal/discovery/manager.go b/internal/discovery/manager.go
deleted file mode 100644
index 4fb4f8e..0000000
--- a/internal/discovery/manager.go
+++ /dev/null
@@ -1,92 +0,0 @@
-package discovery
-
-import (
- "context"
- "log/slog"
- "sync"
-)
-
-// Manager is the discovery struct that is cached
-// with some bookkeeping to avoid race conditions
-type Manager struct {
- disco *Discovery
-
- cancel context.CancelFunc
- mu sync.RWMutex
- wait sync.WaitGroup
-}
-
-// NewManager creates a new Discovery manager
-func NewManager(disco *Discovery) *Manager {
- return &Manager{disco: disco}
-}
-
-func (m *Manager) lock(write bool) {
- if write {
- m.mu.Lock()
- return
- }
- m.mu.RLock()
-}
-
-func (m *Manager) unlock(write bool) {
- if write {
- m.mu.Unlock()
- return
- }
- m.mu.RUnlock()
-}
-
-// Discovery gets the cached discovery
-// `write` is true if discovery will be written to
-func (m *Manager) Discovery(write bool) (*Discovery, func()) {
- if write {
- m.wait.Wait()
- }
- m.lock(write)
- return m.disco, func() {
- m.unlock(write)
- }
-}
-
-// Cancel aborts the running discovery startup process
-func (m *Manager) Cancel() {
- if m.cancel != nil {
- m.cancel()
- }
- m.wait.Wait()
-}
-
-// Startup handles the discovery process in the background
-// It's called Startup because it's called when the lib is initialised
-func (m *Manager) Startup(ctx context.Context, cb func()) {
- ctx, cancel := context.WithCancel(ctx)
- m.cancel = cancel
- m.wait.Add(1)
- go func() {
- m.lock(false)
- discoCopy, err := m.disco.Copy()
- if err != nil {
- slog.Warn("failed to clone discovery", "error", err)
- return
- }
- m.unlock(false)
- // we already log the warning
- discoCopy.Servers(ctx, false) //nolint:errcheck
-
- m.lock(true)
- m.disco.UpdateServers(discoCopy.ServerList)
- m.unlock(true)
- m.wait.Done()
-
- select {
- case <-ctx.Done():
- return
- default:
- if cb == nil {
- return
- }
- cb()
- }
- }()
-}
diff --git a/internal/server/institute.go b/internal/server/institute.go
index 348e895..0a7ef3a 100644
--- a/internal/server/institute.go
+++ b/internal/server/institute.go
@@ -17,15 +17,12 @@ import (
// `disco` are the discovery servers
// `id` is the identifier for the server, the base url
// `ot` specifies specifies the start time OAuth was already triggered
-func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, id string, ot *int64) error {
+func (s *Servers) AddInstitute(ctx context.Context, disco *discovery.Discovery, id string, ot *int64) error {
// This is basically done to double check if the server is part of the institute access section of disco
- disco, release := discom.Discovery(false)
dsrv, err := disco.ServerByURL(id, "institute_access")
if err != nil {
- release()
return err
}
- release()
sd := api.ServerData{
ID: dsrv.BaseURL,
@@ -70,15 +67,12 @@ func (s *Servers) AddInstitute(ctx context.Context, discom *discovery.Manager, i
// `disco` are the discovery servers
// `tok` are the tokens such that we do not have to trigger auth
// `disableAuth` is true when auth should never be triggered
-func (s *Servers) GetInstitute(ctx context.Context, id string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
- disco, release := discom.Discovery(false)
+func (s *Servers) GetInstitute(ctx context.Context, id string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
// This is basically done to double check if the server is part of the institute access section of disco
dsrv, err := disco.ServerByURL(id, "institute_access")
if err != nil {
- release()
return nil, err
}
- release()
// Get the server from the config
_, err = s.config.GetServer(dsrv.BaseURL, server.TypeInstituteAccess)
diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go
index 69b1e97..18a5e78 100644
--- a/internal/server/secureinternet.go
+++ b/internal/server/secureinternet.go
@@ -45,17 +45,14 @@ func ReplaceWAYF(template string, authURL string, orgID string) string {
// `disco` are the discovery servers
// `orgID` is the organiztaion ID
// `ot` specifies specifies the start time OAuth was already triggered
-func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgID string, ot *int64) error {
+func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, orgID string, ot *int64) error {
if s.config.HasSecureInternet() {
return errors.New("a secure internet server already exists")
}
- disco, release := discom.Discovery(false)
dorg, dsrv, err := disco.SecureHomeArgs(orgID)
if err != nil {
- release()
return err
}
- release()
sd := api.ServerData{
ID: dorg.OrgID,
@@ -63,13 +60,11 @@ func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgI
BaseWK: dsrv.BaseURL,
BaseAuthWK: dsrv.BaseURL,
ProcessAuth: func(ctx context.Context, url string) (string, error) {
- newd, release := discom.Discovery(true)
- defer release()
// the only thing we can do is log warn
// this is already done in the functions
- newd.Servers(ctx, false) //nolint:errcheck
- newd.Organizations(ctx, false) //nolint:errcheck
- updorg, updsrv, err := newd.SecureHomeArgs(orgID)
+ disco.Servers(ctx, false) //nolint:errcheck
+ disco.Organizations(ctx, false) //nolint:errcheck
+ updorg, updsrv, err := disco.SecureHomeArgs(orgID)
if err != nil {
return "", err
}
@@ -115,25 +110,21 @@ func (s *Servers) AddSecure(ctx context.Context, discom *discovery.Manager, orgI
// `disco` are the discovery servers
// `tok` are the tokens such that the server can be found without triggering auth
// `disableAuth` is set to true when authorization should not be triggered
-func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery.Manager, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
+func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery.Discovery, tok *eduoauth.Token, disableAuth bool) (*Server, error) {
srv, err := s.config.GetServer(orgID, server.TypeSecureInternet)
if err != nil {
return nil, err
}
- disco, release := discom.Discovery(false)
dorg, dhome, err := disco.SecureHomeArgs(orgID)
if err != nil {
- release()
return nil, err
}
dloc, err := disco.ServerByCountryCode(srv.CountryCode)
if err != nil {
- release()
return nil, err
}
- release()
sd := api.ServerData{
ID: dorg.OrgID,
@@ -141,15 +132,13 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, discom *discovery
BaseWK: dloc.BaseURL,
BaseAuthWK: dhome.BaseURL,
ProcessAuth: func(ctx context.Context, url string) (string, error) {
- newd, release := discom.Discovery(true)
- defer release()
// the only thing we can do is log warn
// this is already done in the functions
- newd.MarkServersExpired()
- newd.Servers(ctx, false) //nolint:errcheck
- newd.MarkOrganizationsExpired()
- newd.Organizations(ctx, false) //nolint:errcheck
- updorg, updsrv, err := newd.SecureHomeArgs(orgID)
+ disco.MarkServersExpired()
+ disco.Servers(ctx, false) //nolint:errcheck
+ disco.MarkOrganizationsExpired()
+ disco.Organizations(ctx, false) //nolint:errcheck
+ updorg, updsrv, err := disco.SecureHomeArgs(orgID)
if err != nil {
return "", err
}
diff --git a/internal/server/servers.go b/internal/server/servers.go
index d1aed97..60b9277 100644
--- a/internal/server/servers.go
+++ b/internal/server/servers.go
@@ -54,12 +54,12 @@ type CurrentServer struct {
}
// ServerWithCallbacks gets the current server as a server struct and triggers callbacks as needed
-func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, discom *discovery.Manager, tokens *eduoauth.Token, disableAuth bool) (*Server, error) {
+func (cs *CurrentServer) ServerWithCallbacks(ctx context.Context, disco *discovery.Discovery, tokens *eduoauth.Token, disableAuth bool) (*Server, error) {
switch cs.Key.T {
case srvtypes.TypeInstituteAccess:
- return cs.srvs.GetInstitute(ctx, cs.Key.ID, discom, tokens, disableAuth)
+ return cs.srvs.GetInstitute(ctx, cs.Key.ID, disco, tokens, disableAuth)
case srvtypes.TypeSecureInternet:
- return cs.srvs.GetSecure(ctx, cs.Key.ID, discom, tokens, disableAuth)
+ return cs.srvs.GetSecure(ctx, cs.Key.ID, disco, tokens, disableAuth)
case srvtypes.TypeCustom:
return cs.srvs.GetCustom(ctx, cs.Key.ID, tokens, disableAuth)
default:
@@ -89,9 +89,7 @@ func (s *Servers) CurrentServer() (*CurrentServer, error) {
}
// PublicCurrent gets the current server into a type that we can return to the client
-func (s *Servers) PublicCurrent(discom *discovery.Manager) (*srvtypes.Current, error) {
- disco, release := discom.Discovery(false)
- defer release()
+func (s *Servers) PublicCurrent(disco *discovery.Discovery) (*srvtypes.Current, error) {
return s.config.PublicCurrent(disco)
}