summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-09-06 15:55:26 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2023-09-25 09:43:37 +0200
commit9697ea01b79cde6c8901d7853dc0b414acf84fa7 (patch)
tree2827e2328de82a59fb840905fbfe295342f3d97f /internal/server
parent2e9dbcb863bf72239a80c7c33f6808d24c3ac69e (diff)
Server: Have separate implementations for refreshing endpoints
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/institute/institute.go22
-rw-r--r--internal/server/secure/secure.go37
-rw-r--r--internal/server/server.go25
3 files changed, 62 insertions, 22 deletions
diff --git a/internal/server/institute/institute.go b/internal/server/institute/institute.go
index e0a52b7..46977ac 100644
--- a/internal/server/institute/institute.go
+++ b/internal/server/institute/institute.go
@@ -3,6 +3,7 @@ package institute
import (
"context"
+ "github.com/eduvpn/eduvpn-common/internal/discovery"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server/api"
"github.com/eduvpn/eduvpn-common/internal/server/base"
@@ -98,6 +99,27 @@ func (s *Server) NeedsLocation() bool {
return false
}
+func (s *Server) RefreshEndpoints(ctx context.Context, _ *discovery.Discovery) error {
+ // Re-initialize the endpoints
+ b, err := s.Base()
+ if err != nil {
+ return err
+ }
+
+ err = api.Endpoints(ctx, b)
+ if err != nil {
+ return err
+ }
+
+ // update OAuth
+ auth := s.OAuth()
+ if auth != nil {
+ auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization
+ auth.TokenURL = b.Endpoints.API.V3.Token
+ }
+ return nil
+}
+
func (s *Server) Public() (interface{}, error) {
return &server.Server{
DisplayName: s.Basic.DisplayName,
diff --git a/internal/server/secure/secure.go b/internal/server/secure/secure.go
index d25bf02..c60c38e 100644
--- a/internal/server/secure/secure.go
+++ b/internal/server/secure/secure.go
@@ -2,6 +2,7 @@ package secure
import (
"context"
+ "github.com/eduvpn/eduvpn-common/internal/discovery"
"github.com/eduvpn/eduvpn-common/internal/oauth"
"github.com/eduvpn/eduvpn-common/internal/server/api"
"github.com/eduvpn/eduvpn-common/internal/server/base"
@@ -58,6 +59,42 @@ func (s *Server) NeedsLocation() bool {
return false
}
+func (s *Server) RefreshEndpoints(ctx context.Context, disco *discovery.Discovery) error {
+ // update OAuth for home server
+ auth := s.OAuth()
+ if auth != nil && s.HomeOrganizationID != "" {
+ _, srv, err := disco.SecureHomeArgs(s.HomeOrganizationID)
+ if err != nil {
+ return err
+ }
+ if hb, ok := s.BaseMap[srv.CountryCode]; ok && hb != nil {
+ err := api.Endpoints(ctx, hb)
+ if err != nil {
+ return err
+ }
+ auth.BaseAuthorizationURL = hb.Endpoints.API.V3.Authorization
+ auth.TokenURL = hb.Endpoints.API.V3.Token
+ }
+ // already updated, return
+ if srv.CountryCode == s.CurrentLocation {
+ return nil
+ }
+ }
+
+ // refresh the current location endpoints
+ // Re-initialize the endpoints
+ b, err := s.Base()
+ if err != nil {
+ return err
+ }
+
+ err = api.Endpoints(ctx, b)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
func (s *Server) addLocation(ctx context.Context, locSrv *discotypes.Server) (*base.Base, error) {
// Initialize the base map if it is non-nil
if s.BaseMap == nil {
diff --git a/internal/server/server.go b/internal/server/server.go
index e8b046f..b6f3b30 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -31,6 +31,9 @@ type Server interface {
// Public returns the representation that will be passed over the CGO barrier
Public() (interface{}, error)
+
+ // RefreshEndpoints refreshes the endpoints for the server
+ RefreshEndpoints(context.Context, *discovery.Discovery) error
}
// Name gets the name for the server and falls back to a default of "Unknown Server"
@@ -227,28 +230,6 @@ func HasValidProfile(ctx context.Context, srv Server, wireguardSupport bool) (bo
return true, nil
}
-func RefreshEndpoints(ctx context.Context, srv Server) error {
- // Get the base struct
- b, err := srv.Base()
- if err != nil {
- return err
- }
-
- // update the base struct
- err = api.Endpoints(ctx, b)
- if err != nil {
- return err
- }
-
- // update OAuth
- auth := srv.OAuth()
- if auth != nil {
- auth.BaseAuthorizationURL = b.Endpoints.API.V3.Authorization
- auth.TokenURL = b.Endpoints.API.V3.Token
- }
- return nil
-}
-
func Config(ctx context.Context, server Server, wireguardSupport bool, preferTCP bool) (*ConfigData, error) {
p, err := CurrentProfile(server)
if err != nil {