diff options
| -rw-r--r-- | client/client.go | 2 | ||||
| -rw-r--r-- | internal/server/institute/institute.go | 22 | ||||
| -rw-r--r-- | internal/server/secure/secure.go | 37 | ||||
| -rw-r--r-- | internal/server/server.go | 25 |
4 files changed, 63 insertions, 23 deletions
diff --git a/client/client.go b/client/client.go index f8863eb..eab704a 100644 --- a/client/client.go +++ b/client/client.go @@ -623,7 +623,7 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes. return nil, err } // refresh the server endpoints - err = server.RefreshEndpoints(ck.Context(), srv) + err = srv.RefreshEndpoints(ck.Context(), &c.Discovery) // If we get a canceled error, return that, otherwise just log the error if err != nil { diff --git a/internal/server/institute/institute.go b/internal/server/institute/institute.go index e0a52b7..46977ac 100644 --- a/internal/server/institute/institute.go +++ b/internal/server/institute/institute.go @@ -3,6 +3,7 @@ package institute import ( "context" + "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server/api" "github.com/eduvpn/eduvpn-common/internal/server/base" @@ -98,6 +99,27 @@ func (s *Server) NeedsLocation() bool { return false } +func (s *Server) RefreshEndpoints(ctx context.Context, _ *discovery.Discovery) error { + // Re-initialize the endpoints + b, err := s.Base() + if err != nil { + return err + } + + err = api.Endpoints(ctx, b) + if err != nil { + return err + } + + // update OAuth + auth := s.OAuth() + if auth != nil { + auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization + auth.TokenURL = b.Endpoints.API.V3.Token + } + return nil +} + func (s *Server) Public() (interface{}, error) { return &server.Server{ DisplayName: s.Basic.DisplayName, diff --git a/internal/server/secure/secure.go b/internal/server/secure/secure.go index d25bf02..c60c38e 100644 --- a/internal/server/secure/secure.go +++ b/internal/server/secure/secure.go @@ -2,6 +2,7 @@ package secure import ( "context" + "github.com/eduvpn/eduvpn-common/internal/discovery" "github.com/eduvpn/eduvpn-common/internal/oauth" "github.com/eduvpn/eduvpn-common/internal/server/api" "github.com/eduvpn/eduvpn-common/internal/server/base" @@ -58,6 +59,42 @@ func (s *Server) NeedsLocation() bool { return false } +func (s *Server) RefreshEndpoints(ctx context.Context, disco *discovery.Discovery) error { + // update OAuth for home server + auth := s.OAuth() + if auth != nil && s.HomeOrganizationID != "" { + _, srv, err := disco.SecureHomeArgs(s.HomeOrganizationID) + if err != nil { + return err + } + if hb, ok := s.BaseMap[srv.CountryCode]; ok && hb != nil { + err := api.Endpoints(ctx, hb) + if err != nil { + return err + } + auth.BaseAuthorizationURL = hb.Endpoints.API.V3.Authorization + auth.TokenURL = hb.Endpoints.API.V3.Token + } + // already updated, return + if srv.CountryCode == s.CurrentLocation { + return nil + } + } + + // refresh the current location endpoints + // Re-initialize the endpoints + b, err := s.Base() + if err != nil { + return err + } + + err = api.Endpoints(ctx, b) + if err != nil { + return err + } + return nil +} + func (s *Server) addLocation(ctx context.Context, locSrv *discotypes.Server) (*base.Base, error) { // Initialize the base map if it is non-nil if s.BaseMap == nil { diff --git a/internal/server/server.go b/internal/server/server.go index e8b046f..b6f3b30 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -31,6 +31,9 @@ type Server interface { // Public returns the representation that will be passed over the CGO barrier Public() (interface{}, error) + + // RefreshEndpoints refreshes the endpoints for the server + RefreshEndpoints(context.Context, *discovery.Discovery) error } // Name gets the name for the server and falls back to a default of "Unknown Server" @@ -227,28 +230,6 @@ func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bo return true, nil } -func RefreshEndpoints(ctx context.Context, srv Server) error { - // Get the base struct - b, err := srv.Base() - if err != nil { - return err - } - - // update the base struct - err = api.Endpoints(ctx, b) - if err != nil { - return err - } - - // update OAuth - auth := srv.OAuth() - if auth != nil { - auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization - auth.TokenURL = b.Endpoints.API.V3.Token - } - return nil -} - func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) { p, err := CurrentProfile(server) if err != nil { |
