diff options
| -rw-r--r-- | client/client.go | 80 | ||||
| -rw-r--r-- | client/discovery.go | 80 | ||||
| -rw-r--r-- | exports/exports.go | 27 | ||||
| -rw-r--r-- | internal/api/api_test.go | 2 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 21 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 12 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/loader.py | 2 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/main.py | 13 | ||||
| -rw-r--r-- | wrappers/python/eduvpn_common/types.py | 1 |
9 files changed, 203 insertions, 35 deletions
diff --git a/client/client.go b/client/client.go index 9b51d82..0fd4230 100644 --- a/client/client.go +++ b/client/client.go @@ -53,20 +53,12 @@ type Client struct { proxy Proxy mu sync.Mutex + + discoMan DiscoManager } -// MarkOrganizationsExpired marks the discovery organization list as expired -// it's a no-op if the type `t` is not secure internet -// or if discovery is nil -func (c *Client) MarkOrganizationsExpired(t srvtypes.Type) { - if t != srvtypes.TypeSecureInternet { - return - } - disco := c.cfg.Discovery() - if disco == nil { - return - } - disco.MarkOrganizationsExpired() +func (c *Client) DiscoveryStartup(cb func()) { + c.discoMan.Startup(context.Background(), cb) } // GettingConfig is defined here to satisfy the server.Callbacks interface @@ -167,12 +159,19 @@ func New(name string, version string, directory string, stateCallback func(FSMSt // set the servers c.Servers = server.NewServers(c.Name, c, c.cfg.V2) - // the first fetch for the servers should be fresh - c.cfg.Discovery().MarkServersExpired() + c.discoMan = DiscoManager{disco: c.cfg.Discovery()} + + if !c.hasDiscovery() { + return c, nil + } + disco, release := c.discoMan.Discovery(true) + defer release() + disco.MarkServersExpired() if !c.cfg.HasSecureInternet() { - c.cfg.Discovery().MarkOrganizationsExpired() + disco.MarkOrganizationsExpired() } + return c, nil } @@ -255,6 +254,9 @@ func (c *Client) Register() error { // Deregister 'deregisters' the client, meaning saving the log file and the config and emptying out the client struct. func (c *Client) Deregister() { + c.discoMan.Cancel() + + _, release := c.discoMan.Discovery(false) // save the config c.TrySave() @@ -266,6 +268,7 @@ func (c *Client) Deregister() { // Close the log file _ = log.Logger.Close() + release() // Empty out the state *c = Client{} @@ -290,8 +293,8 @@ func (c *Client) ExpiryTimes() (*srvtypes.Expiry, error) { }, nil } -func (c *Client) locationCallback(ck *cookie.Cookie, orgID string) error { - locs := c.cfg.Discovery().SecureLocationList() +func (c *Client) locationCallback(ck *cookie.Cookie, disco *discovery.Discovery, orgID string) error { + locs := disco.SecureLocationList() errChan := make(chan error) go func() { err := c.FSM.GoTransitionRequired(StateAskLocation, &srvtypes.RequiredAskTransition{ @@ -342,6 +345,8 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. // If we have failed to add the server, we remove it again // We add the server because we can then obtain it in other callback functions previousState := c.FSM.Current + + var release func() defer func() { // If we must run callbacks, go to the previous state if we're not in it if !ni && !c.FSM.InState(previousState) { @@ -350,6 +355,9 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. if err == nil { c.TrySave() } + if release != nil { + release() + } }() if !ni { @@ -367,14 +375,16 @@ func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes. } } + var disco *discovery.Discovery + disco, release = c.discoMan.Discovery(true) switch _type { case srvtypes.TypeInstituteAccess: - err = c.Servers.AddInstitute(ck.Context(), c.cfg.Discovery(), identifier, ot) + err = c.Servers.AddInstitute(ck.Context(), disco, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add an institute access server with URL: '%s'", identifier) } case srvtypes.TypeSecureInternet: - err = c.Servers.AddSecure(ck.Context(), c.cfg.Discovery(), identifier, ot) + err = c.Servers.AddSecure(ck.Context(), disco, identifier, ot) if err != nil { return i18nerr.Wrapf(err, "Failed to add a secure internet server with organisation ID: '%s'", identifier) } @@ -412,6 +422,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. return nil, i18nerr.NewInternalf("Getting a non-custom server when the client does not use discovery is not supported, identifier: %s, type: %d", identifier, _type) } + var release func() defer func() { if err == nil { // it could be that we are not in getting config yet if we have just done authorization @@ -422,6 +433,9 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. c.FSM.GoTransition(previousState) //nolint:errcheck } c.TrySave() + if release != nil { + release() + } }() identifier, err = c.convertIdentifier(identifier, _type) @@ -439,7 +453,8 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. } ctx := ck.Context() - disco := c.cfg.Discovery() + var disco *discovery.Discovery + disco, release = c.discoMan.Discovery(true) if _type != srvtypes.TypeCustom { // make sure the servers are fetched fresh _, _, dserverr := disco.Servers(ctx) @@ -462,7 +477,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. var cErr *discovery.ErrCountryNotFound if errors.As(err, &cErr) { - err = c.locationCallback(ck, identifier) + err = c.locationCallback(ck, disco, identifier) if err == nil { srv, err = c.Servers.GetSecure(ctx, identifier, disco, tok, startup) } @@ -499,14 +514,21 @@ func (c *Client) RemoveServer(identifier string, _type srvtypes.Type) (err error if err != nil { return i18nerr.WrapInternalf(err, "Failed to remove server: '%s'", identifier) } - c.MarkOrganizationsExpired(_type) + disco, release := c.discoMan.Discovery(true) + defer release() + if _type == srvtypes.TypeSecureInternet { + disco.MarkOrganizationsExpired() + } c.TrySave() return nil } // CurrentServer gets the current server that is configured func (c *Client) CurrentServer() (*srvtypes.Current, error) { - curr, err := c.Servers.PublicCurrent(c.cfg.Discovery()) + // TODO: do clients call this during a write mutex? + disco, release := c.discoMan.Discovery(false) + defer release() + curr, err := c.Servers.PublicCurrent(disco) if err != nil { return nil, i18nerr.WrapInternal(err, "The current server could not be retrieved") } @@ -561,7 +583,9 @@ func (c *Client) Cleanup(ck *cookie.Cookie) error { if err != nil { return i18nerr.WrapInternal(err, "No OAuth tokens were found when cleaning up the connection") } - auth, err := srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), tok, true) + disco, release := c.discoMan.Discovery(true) + defer release() + auth, err := srv.ServerWithCallbacks(ck.Context(), disco, tok, true) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when cleaning up the connection") } @@ -608,7 +632,9 @@ func (c *Client) RenewSession(ck *cookie.Cookie) error { } // getting a server with no tokens means re-authorize - _, err = srv.ServerWithCallbacks(ck.Context(), c.cfg.Discovery(), nil, false) + disco, release := c.discoMan.Discovery(true) + defer release() + _, err = srv.ServerWithCallbacks(ck.Context(), disco, nil, false) if err != nil { return i18nerr.WrapInternal(err, "The server was unable to be retrieved when renewing the session") } @@ -629,6 +655,8 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR // ServerList gets the list of servers func (c *Client) ServerList() (*srvtypes.List, error) { - g := c.cfg.V2.PublicList(c.cfg.Discovery()) + disco, release := c.discoMan.Discovery(false) + defer release() + g := c.cfg.V2.PublicList(disco) return g, nil } diff --git a/client/discovery.go b/client/discovery.go index 415be9b..2758678 100644 --- a/client/discovery.go +++ b/client/discovery.go @@ -1,9 +1,13 @@ package client import ( + "context" "sort" "strings" + "sync" + "github.com/eduvpn/eduvpn-common/internal/discovery" + "github.com/eduvpn/eduvpn-common/internal/log" "github.com/eduvpn/eduvpn-common/i18nerr" "github.com/eduvpn/eduvpn-common/types/cookie" discotypes "github.com/eduvpn/eduvpn-common/types/discovery" @@ -24,7 +28,10 @@ func (c *Client) DiscoOrganizations(ck *cookie.Cookie, search string) (*discotyp return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported") } - orgs, fresh, err := c.cfg.Discovery().Organizations(ck.Context()) + disco, release := c.discoMan.Discovery(true) + defer release() + + orgs, fresh, err := disco.Organizations(ck.Context()) if fresh { defer c.TrySave() } @@ -70,7 +77,9 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser return nil, i18nerr.NewInternal("Server/organization discovery with this client ID is not supported") } - servs, fresh, err := c.cfg.Discovery().Servers(ck.Context()) + disco, release := c.discoMan.Discovery(true) + defer release() + servs, fresh, err := disco.Servers(ck.Context()) if fresh { defer c.TrySave() } @@ -105,3 +114,70 @@ func (c *Client) DiscoServers(ck *cookie.Cookie, search string) (*discotypes.Ser List: retServs, }, err } + +type DiscoManager struct { + disco *discovery.Discovery + + cancel context.CancelFunc + mu sync.RWMutex + wait sync.WaitGroup +} + +func (m *DiscoManager) lock(write bool) { + if write { + m.mu.Lock() + return + } + m.mu.RLock() +} + +func (m *DiscoManager) unlock(write bool) { + if write { + m.mu.Unlock() + return + } + m.mu.RUnlock() +} + +func (m *DiscoManager) Discovery(write bool) (*discovery.Discovery, func()) { + if write { + m.wait.Wait() + } + m.lock(write) + return m.disco, func() { + m.unlock(write) + } +} + +func (m *DiscoManager) Cancel() { + if m.cancel != nil { + m.cancel() + } + m.wait.Wait() +} + +func (m *DiscoManager) Startup(ctx context.Context, cb func()) { + ctx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wait.Add(1) + go func() { + defer m.wait.Done() + m.lock(false) + discoCopy, err := m.disco.Copy() + if err != nil { + log.Logger.Warningf("internal error, failed to clone discovery, %v", err) + return + } + m.unlock(false) + // we already log the warning + discoCopy.Servers(ctx) //nolint:errcheck + + m.lock(true) + m.disco.UpdateServers(discoCopy) + m.unlock(true) + + if cb != nil { + cb() + } + }() +} diff --git a/exports/exports.go b/exports/exports.go index 1e496f8..36372b4 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -21,6 +21,7 @@ typedef long long int (*ReadRxBytes)(); typedef int (*StateCB)(int oldstate, int newstate, void* data); +typedef void (*RefreshList)(); typedef void (*TokenGetter)(const char* server_id, int server_type, char* out, size_t len); typedef void (*TokenSetter)(const char* server_id, int server_type, const char* tokens); typedef void (*ProxySetup)(int fd, const char* peer_ips); @@ -34,6 +35,10 @@ static int call_callback(StateCB callback, int oldstate, int newstate, void* dat { return callback(oldstate, newstate, data); } +static void call_refresh_list(RefreshList refresh) +{ + refresh(); +} static void call_token_getter(TokenGetter getter, const char* server_id, int server_type, char* out, size_t len) { getter(server_id, server_type, out, len); @@ -1024,6 +1029,28 @@ func getCookie(c C.uintptr_t) (*cookie.Cookie, error) { return v, nil } + +// DiscoveryStartup does a discovery request in the background +// +// The `refresh` argument is a callback that is called when the refreshing is done +// When this callback is thus called, the app SHOULD refresh the server list of the already configured servers +// This DiscoveryStartup function MUST be called after calling `Register` +// +//export DiscoveryStartup +func DiscoveryStartup(refresh C.RefreshList) *C.char { + state, stateErr := getVPNState() + if stateErr != nil { + return getCError(stateErr) + } + state.DiscoveryStartup(func() { + if refresh == nil { + return + } + C.call_refresh_list(refresh) + }) + return nil +} + // SetTokenHandler sets the token getters and token setters for OAuth // // Because the data that is saved does not contain OAuth tokens for server, the common lib asks and sets the tokens using these callback functions. diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 397dd3c..fcf02e9 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -196,7 +196,7 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha Type: server.TypeCustom, BaseWK: serv.URL, BaseAuthWK: serv.URL, - ProcessAuth: func(ctx context.Context, in string) (string, error) { + ProcessAuth: func(_ context.Context, in string) (string, error) { return in, nil }, DisableAuthorize: false, diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index 30ca801..231c12a 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -405,3 +405,24 @@ func (discovery *Discovery) Servers(ctx context.Context) (*Servers, bool, error) } return &discovery.ServerList, true, nil } + +func (discovery *Discovery) UpdateServers(other Discovery) { + if other.ServerList.Version >= discovery.ServerList.Version { + discovery.ServerList = other.ServerList + } +} + +func (discovery *Discovery) Copy() (Discovery, error) { + var dest Discovery + b, err := json.Marshal(discovery) + if err != nil { + return dest, err + } + + err = json.Unmarshal(b, &dest) + if err != nil { + return dest, err + } + + return dest, nil +} diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index f167756..990ceb3 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -34,8 +34,10 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org BaseWK: dsrv.BaseURL, BaseAuthWK: dsrv.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { - disco.Servers(ctx) - disco.Organizations(ctx) + // the only thing we can do is log warn + // this is already done in the functions + disco.Servers(ctx) //nolint:errcheck + disco.Organizations(ctx) //nolint:errcheck updorg, updsrv, err := disco.SecureHomeArgs(orgID) if err != nil { return "", err @@ -104,10 +106,12 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery. BaseWK: dloc.BaseURL, BaseAuthWK: dhome.BaseURL, ProcessAuth: func(ctx context.Context, url string) (string, error) { + // the only thing we can do is log warn + // this is already done in the functions disco.MarkServersExpired() - disco.Servers(ctx) + disco.Servers(ctx) //nolint:errcheck disco.MarkOrganizationsExpired() - disco.Organizations(ctx) + disco.Organizations(ctx) //nolint:errcheck updorg, updsrv, err := disco.SecureHomeArgs(orgID) if err != nil { return "", err diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py index b74741f..4bfc55f 100644 --- a/wrappers/python/eduvpn_common/loader.py +++ b/wrappers/python/eduvpn_common/loader.py @@ -8,6 +8,7 @@ from eduvpn_common.types import ( ProxyReady, ProxySetup, ReadRxBytes, + RefreshList, TokenGetter, TokenSetter, VPNStateChange, @@ -90,6 +91,7 @@ def initialize_functions(lib: CDLL) -> None: c_void_p, ) lib.RenewSession.argtypes, lib.RenewSession.restype = [c_int], c_void_p + lib.DiscoveryStartup.argtypes, lib.DiscoveryStartup.restype = [RefreshList], c_void_p lib.SetTokenHandler.argtypes, lib.SetTokenHandler.restype = ( [ TokenGetter, diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py index cb81e53..8c556e9 100644 --- a/wrappers/python/eduvpn_common/main.py +++ b/wrappers/python/eduvpn_common/main.py @@ -10,6 +10,7 @@ from eduvpn_common.types import ( ProxyReady, ProxySetup, ReadRxBytes, + RefreshList, TokenGetter, TokenSetter, VPNStateChange, @@ -180,12 +181,14 @@ class EduVPN(object): def get_disco_organizations(self, search="") -> str: orgs, _ = self.go_cookie_function(self.lib.DiscoOrganizations, search) - # TODO: Log error + # We don't log anything here as we want to return a result and we don't want to throw here + # we already log for errors in common return orgs def get_disco_servers(self, search="") -> str: servers, _ = self.go_cookie_function(self.lib.DiscoServers, search) - # TODO: Log error + # We don't log anything here as we want to return a result and we don't want to throw here + # we already log for errors in common return servers def get_servers(self) -> str: @@ -298,6 +301,12 @@ class EduVPN(object): if location_err: forwardError(location_err) + def discovery_startup(self, refresh: RefreshList) -> None: + refresh_err = self.go_function(self.lib.DiscoveryStartup, refresh) + + if refresh_err: + forwardError(refresh_err) + def set_token_handler(self, getter: Callable, setter: Callable) -> None: self.token_setter = setter self.token_getter = getter diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py index 21690fb..716556e 100644 --- a/wrappers/python/eduvpn_common/types.py +++ b/wrappers/python/eduvpn_common/types.py @@ -39,6 +39,7 @@ VPNStateChange = CFUNCTYPE(c_int, c_int, c_int, c_char_p) ProxySetup = CFUNCTYPE(c_void_p, c_int, c_char_p) ProxyReady = CFUNCTYPE(c_void_p) ReadRxBytes = CFUNCTYPE(c_ulonglong) +RefreshList = CFUNCTYPE(c_void_p) TokenGetter = CFUNCTYPE(c_void_p, c_char_p, c_int, POINTER(c_char), c_size_t) TokenSetter = CFUNCTYPE(c_void_p, c_char_p, c_int, c_char_p) |
