diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-11-10 10:00:46 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-11-10 10:02:32 +0100 |
| commit | 6df9ea2134103d02ecbe73e945b5df5fb7b131a1 (patch) | |
| tree | dc6fb435d049fa273b344baae4f0394ad95c317e /internal/server | |
| parent | f87b8d988f25c112f457baa24609a44ba85d4f34 (diff) | |
Server: Re-initialize endpoints when getting a config
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/common.go | 22 | ||||
| -rw-r--r-- | internal/server/instituteaccess.go | 6 | ||||
| -rw-r--r-- | internal/server/secureinternet.go | 3 |
3 files changed, 26 insertions, 5 deletions
diff --git a/internal/server/common.go b/internal/server/common.go index 443a925..1b745cb 100644 --- a/internal/server/common.go +++ b/internal/server/common.go @@ -328,6 +328,16 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { ) } +func (base *ServerBase) InitializeEndpoints() error { + errorMessage := "failed initializing endpoints" + endpoints, endpointsErr := APIGetEndpoints(base.URL) + if endpointsErr != nil { + return types.NewWrappedError(errorMessage, endpointsErr) + } + base.Endpoints = *endpoints + return nil +} + func (base *ServerBase) GetValidProfiles(clientSupportsWireguard bool) ServerProfileInfo { var validProfiles []ServerProfile for _, profile := range base.Profiles.Info.ProfileList { @@ -466,6 +476,18 @@ func HasValidProfile(server Server, clientSupportsWireguard bool) (bool, error) func GetConfig(server Server, clientSupportsWireguard bool, preferTCP bool) (string, string, error) { errorMessage := "failed getting an OpenVPN/WireGuard configuration" + // Re-initialize the endpoints + // TODO: Make this a warning instead? + base, baseErr := server.GetBase() + if baseErr != nil { + return "", "", types.NewWrappedError(errorMessage, baseErr) + } + + endpointsErr := base.InitializeEndpoints() + if endpointsErr != nil { + return "", "", types.NewWrappedError(errorMessage, endpointsErr) + } + profile, profileErr := getCurrentProfile(server) if profileErr != nil { return "", "", types.NewWrappedError(errorMessage, profileErr) diff --git a/internal/server/instituteaccess.go b/internal/server/instituteaccess.go index ed0211b..045535a 100644 --- a/internal/server/instituteaccess.go +++ b/internal/server/instituteaccess.go @@ -89,11 +89,11 @@ func (institute *InstituteAccessServer) init( institute.Base.DisplayName = displayName institute.Base.SupportContact = supportContact institute.Base.Type = serverType - endpoints, endpointsErr := APIGetEndpoints(url) + endpointsErr := institute.Base.InitializeEndpoints() if endpointsErr != nil { return types.NewWrappedError(errorMessage, endpointsErr) } - institute.OAuth.Init(url, endpoints.API.V3.Authorization, endpoints.API.V3.Token) - institute.Base.Endpoints = *endpoints + API := institute.Base.Endpoints.API.V3 + institute.OAuth.Init(url, API.Authorization, API.Token) return nil } diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index b3e2615..cfe9ea1 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -111,11 +111,10 @@ func (secure *SecureInternetHomeServer) addLocation( base.DisplayName = secure.DisplayName base.SupportContact = locationServer.SupportContact base.Type = "secure_internet" - endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL) + endpointsErr := base.InitializeEndpoints() if endpointsErr != nil { return nil, types.NewWrappedError(errorMessage, endpointsErr) } - base.Endpoints = *endpoints } // Ensure it is in the map |
