diff options
| -rw-r--r-- | client/server.go | 11 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 22 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 37 | ||||
| -rw-r--r-- | internal/server/server.go | 25 |
4 files changed, 73 insertions, 22 deletions
diff --git a/client/server.go b/client/server.go index 9d3b3b4..283c531 100644 --- a/client/server.go +++ b/client/server.go @@ -67,7 +67,7 @@ func (c *Client) getConfig(srv server.Server, preferTCP bool, t oauth.Token) (*C // Refresh the server endpoints // This is the best effort - err := server.RefreshEndpoints(srv) + err := srv.RefreshEndpoints(&c.Discovery) if err != nil { log.Logger.Warningf("failed to refresh server endpoints: %v", err) } @@ -103,6 +103,10 @@ func (c *Client) Cleanup(ct oauth.Token) error { c.logError(err) return err } + err = srv.RefreshEndpoints(&c.Discovery) + if err != nil { + log.Logger.Warningf("failed to refresh server endpoints: %v", err) + } // If we need to relogin, update tokens if server.NeedsRelogin(srv) { @@ -552,6 +556,11 @@ func (c *Client) RenewSession() (err error) { return err } + err = srv.RefreshEndpoints(&c.Discovery) + if err != nil { + log.Logger.Warningf("failed to refresh server endpoints: %v", err) + } + // The server has not been chosen yet, this means that we want to manually renew if c.FSM.InState(StateNoServer) { c.FSM.GoTransition(StateChosenServer) diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index 050e721..a51409f 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -1,6 +1,7 @@ package server import ( + "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/go-errors/errors" ) @@ -72,6 +73,27 @@ func (ias *InstituteAccessServer) OAuth() *oauth.OAuth { return &ias.Auth } +func (ias *InstituteAccessServer) RefreshEndpoints(_ *discovery.Discovery) error { + // Re-initialize the endpoints + b, err := ias.Base() + if err != nil { + return err + } + + err = b.InitializeEndpoints() + if err != nil { + return err + } + + // update OAuth + auth := ias.OAuth() + if auth != nil { + auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization + auth.TokenURL = b.Endpoints.API.V3.Token + } + return nil +} + func (ias *InstituteAccessServer) init( url string, name map[string]string, diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 9b1d394..3c40253 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -1,6 +1,7 @@ package server import ( + "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/util" "github.com/eduvpn/eduvpn-common/types" @@ -136,3 +137,39 @@ func (s *SecureInternetHomeServer) init( s.Auth.Init(b.URL, b.Endpoints.API.V3.Authorization, b.Endpoints.API.V3.Token) return nil } + +func (s *SecureInternetHomeServer) RefreshEndpoints(disco *discovery.Discovery) error { + // update OAuth for home server + auth := s.OAuth() + if auth != nil && s.HomeOrganizationID != "" { + _, srv, err := disco.SecureHomeArgs(s.HomeOrganizationID) + if err != nil { + return err + } + if hb, ok := s.BaseMap[srv.CountryCode]; ok && hb != nil { + err := hb.InitializeEndpoints() + if err != nil { + return err + } + auth.BaseAuthorizationURL = hb.Endpoints.API.V3.Authorization + auth.TokenURL = hb.Endpoints.API.V3.Token + } + // already updated, return + if srv.CountryCode == s.CurrentLocation { + return nil + } + } + + // refresh the current location endpoints + // Re-initialize the endpoints + b, err := s.Base() + if err != nil { + return err + } + + err = b.InitializeEndpoints() + if err != nil { + return err + } + return nil +} diff --git a/internal/server/server.go b/internal/server/server.go index f62b882..775095c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -4,6 +4,7 @@ import ( "os" "time" + "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/wireguard" "github.com/go-errors/errors" @@ -25,6 +26,9 @@ type Server interface { // Base returns the server base Base() (*Base, error) + + // RefreshEndpoints + RefreshEndpoints(*discovery.Discovery) error } type EndpointList struct { @@ -251,27 +255,6 @@ func HasValidProfile(srv Server, wireguardSupport bool) (bool, error) { return true, nil } -func RefreshEndpoints(srv Server) error { - // Re-initialize the endpoints - b, err := srv.Base() - if err != nil { - return err - } - - err = b.InitializeEndpoints() - if err != nil { - return err - } - - // update OAuth - auth := srv.OAuth() - if auth != nil { - auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization - auth.TokenURL = b.Endpoints.API.V3.Token - } - return nil -} - func Config(server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { |
