summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-07-05 13:17:24 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-07-05 13:17:24 +0200
commit1865b016d0cca74cd3703db5a3b4217917988dec (patch)
tree3da84dbc4f1ad49221c25fb83f402d27deb34138 /internal
parente39b9a8a405fa8e5f73c32bb03a3f349f7f9f92d (diff)
Refactor: Handling of different servers and identifiers
- Uses OrgID for Secure Internet and gets the data from discovery - Uses URL for Institute/Custom and gets the data from discovery - Implements SKIP WAYF as we now have the needed data - Implements an initial change location with a default location (NL right now)
Diffstat (limited to 'internal')
-rw-r--r--internal/discovery/discovery.go145
-rw-r--r--internal/oauth/oauth.go13
-rw-r--r--internal/server/server.go135
-rw-r--r--internal/types/server.go42
-rw-r--r--internal/util/util.go23
5 files changed, 274 insertions, 84 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index ac3bf57..c61469d 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -12,27 +12,15 @@ import (
"github.com/jwijenbergh/eduvpn-common/internal/verify"
)
-type OrganizationList struct {
- JSON json.RawMessage `json:"organization_list"`
- Version uint64 `json:"v"`
- Timestamp int64 `json:"-"`
-}
-
-type ServersList struct {
- JSON json.RawMessage `json:"server_list"`
- Version uint64 `json:"v"`
- Timestamp int64 `json:"-"`
-}
-
type Discovery struct {
- Organizations OrganizationList
- Servers ServersList
+ Organizations types.DiscoveryOrganizations
+ Servers types.DiscoveryServers
FSM *fsm.FSM
Logger *log.FileLogger
}
// Helper function that gets a disco json
-func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
+func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) (string, error) {
errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile)
// Get json data
discoURL := "https://disco.eduvpn.org/v2/"
@@ -40,7 +28,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, fileBody, fileErr := http.HTTPGet(fileURL)
if fileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
}
// Get signature
@@ -49,7 +37,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, sigBody, sigFileErr := http.HTTPGet(sigURL)
if sigFileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
}
// Verify signature
@@ -58,17 +46,17 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash)
if !verifySuccess || verifyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
}
// Parse JSON to extract version and list
jsonErr := json.Unmarshal(fileBody, structure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
}
- return nil
+ return string(fileBody), nil
}
func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
@@ -82,7 +70,63 @@ func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
// - [TODO] when the user tries to add new server AND the user did NOT yet choose an organization before;
// - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation.
func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
- return string(discovery.Organizations.JSON) == ""
+ return discovery.Organizations.Timestamp == 0
+}
+
+func (discovery *Discovery) GetSecureLocationList() []string {
+ var locations []string
+ for _, server := range discovery.Servers.List {
+ if server.Type == "secure_internet" {
+ locations = append(locations, server.CountryCode)
+ }
+ }
+ return locations
+}
+
+func (discovery *Discovery) GetServerByURL(url string, _type string) (*types.DiscoveryServer, error) {
+ for _, server := range discovery.Servers.List {
+ if server.BaseURL == url && server.Type == _type {
+ return &server, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting server by URL from discovery", Err: &GetServerByURLNotFoundError{URL: url, Type: _type}}
+}
+
+func (discovery *Discovery) GetServerByCountryCode(code string, _type string) (*types.DiscoveryServer, error) {
+ for _, server := range discovery.Servers.List {
+ if server.CountryCode == code && server.Type == _type {
+ return &server, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting server by country code from discovery", Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}}
+}
+
+func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) {
+ for _, organization := range discovery.Organizations.List {
+ if organization.OrgId == orgID {
+ return &organization, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting Secure Internet Home URL from discovery", Err: &GetOrgByIDNotFoundError{ID: orgID}}
+}
+
+func (discovery *Discovery) GetSecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) {
+ errorMessage := "failed getting Secure Internet Home arguments from discovery"
+ org, orgErr := discovery.getOrgByID(orgID)
+
+ if orgErr != nil {
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr}
+ }
+
+ // Get a server with the base url
+ url := org.SecureInternetHome
+
+ server, serverErr := discovery.GetServerByURL(url, "secure_internet")
+
+ if serverErr != nil {
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+ return org, server, nil
}
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
@@ -90,7 +134,7 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
// - The application MAY refresh the server_list.json periodically, e.g. once every hour.
func (discovery *Discovery) DetermineServersUpdate() bool {
// No servers, we should update
- if string(discovery.Servers.JSON) == "" {
+ if discovery.Servers.Timestamp == 0 {
return true
}
// 1 hour from the last update
@@ -106,29 +150,66 @@ func (discovery *Discovery) DetermineServersUpdate() bool {
// Get the organization list
func (discovery *Discovery) GetOrganizationsList() (string, error) {
if !discovery.DetermineOrganizationsUpdate() {
- return string(discovery.Organizations.JSON), nil
+ return discovery.Organizations.RawString, nil
}
file := "organization_list.json"
- err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
- if err != nil {
+ body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
+ if bodyErr != nil {
// Return previous with an error
- return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err}
+ return discovery.Organizations.RawString, &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: bodyErr}
}
- return string(discovery.Organizations.JSON), nil
+ discovery.Organizations.RawString = body
+ discovery.Organizations.Timestamp = util.GenerateTimeSeconds()
+ return discovery.Organizations.RawString, nil
}
// Get the server list
func (discovery *Discovery) GetServersList() (string, error) {
if !discovery.DetermineServersUpdate() {
- return string(discovery.Servers.JSON), nil
+ return discovery.Servers.RawString, nil
}
file := "server_list.json"
- err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
- if err != nil {
+ body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
+ if bodyErr != nil {
// Return previous with an error
- return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err}
+ return discovery.Servers.RawString, &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: bodyErr}
}
// Update servers timestamp
+ discovery.Servers.RawString = body
discovery.Servers.Timestamp = util.GenerateTimeSeconds()
- return string(discovery.Servers.JSON), nil
+ return discovery.Servers.RawString, nil
+}
+
+type GetOrgByIDNotFoundError struct {
+ ID string
+}
+
+func (e GetOrgByIDNotFoundError) Error() string {
+ return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s", e.ID)
+}
+
+type GetServerByURLNotFoundError struct {
+ URL string
+ Type string
+}
+
+func (e GetServerByURLNotFoundError) Error() string {
+ return fmt.Sprintf("No institute access server found in organizations with URL %s and type %s", e.URL, e.Type)
+}
+
+type GetServerByCountryCodeNotFoundError struct {
+ CountryCode string
+ Type string
+}
+
+func (e GetServerByCountryCodeNotFoundError) Error() string {
+ return fmt.Sprintf("No institute access server found in organizations with country code %s and type %s", e.CountryCode, e.Type)
+}
+
+type GetSecureHomeArgsNotFoundError struct {
+ URL string
+}
+
+func (e GetSecureHomeArgsNotFoundError) Error() string {
+ return fmt.Sprintf("No Secure Internet Home found with URL: %s", e.URL)
}
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 824db90..ef1bed4 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -223,11 +223,6 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
}
}
-func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) {
- oauth.FSM = fsm
- oauth.Logger = logger
-}
-
func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM, logger *log.FileLogger) {
oauth.BaseAuthorizationURL = baseAuthorizationURL
oauth.TokenURL = tokenURL
@@ -236,7 +231,7 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.
}
// Starts the OAuth exchange for eduvpn.
-func (oauth *OAuth) start(name string) error {
+func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error {
errorMessage := "failed starting OAuth exchange"
if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) {
return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
@@ -274,7 +269,7 @@ func (oauth *OAuth) start(name string) error {
oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
oauth.Session = oauthSession
// Run the state callback in the background so that the user can login while we start the callback server
- oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true)
+ oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true)
return nil
}
@@ -298,9 +293,9 @@ func (oauth *OAuth) Cancel() {
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Login(name string) error {
+func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error {
errorMessage := "failed OAuth login"
- authInitializeErr := oauth.start(name)
+ authInitializeErr := oauth.start(name, postprocessAuth)
if authInitializeErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
diff --git a/internal/server/server.go b/internal/server/server.go
index 807bd09..5bc2ea1 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -36,14 +36,16 @@ type InstituteAccessServer struct {
// A secure internet server which has its own OAuth tokens
// It specifies the current location url it is connected to
type SecureInternetHomeServer struct {
+ DisplayName string `json:"display_name"`
OAuth oauth.OAuth `json:"oauth"`
- // The home server has a list of info for each configured server
+ // The home server has a list of info for each configured server location
BaseMap map[string]*ServerBase `json:"base_map"`
- // We have the home url and the current url
- HomeURL string `json:"home_url"`
- CurrentURL string `json:"current_url"`
+ // We have the authorization URL template, the home organization ID and the current location
+ AuthorizationTemplate string `json:"authorization_template"`
+ HomeOrganizationID string `json:"home_organization_id"`
+ CurrentLocation string `json:"current_location"`
}
type InstituteServers struct {
@@ -89,11 +91,11 @@ type Server interface {
// Gets the current OAuth object
GetOAuth() *oauth.OAuth
+ // Get the authorization URL template function
+ GetTemplateAuth() func(string) string
+
// Gets the server base
GetBase() (*ServerBase, error)
-
- // initialize method
- init(url string, fsm *fsm.FSM, logger *log.FileLogger) error
}
// For an institute, we can simply get the OAuth
@@ -105,6 +107,19 @@ func (secure *SecureInternetHomeServer) GetOAuth() *oauth.OAuth {
return &secure.OAuth
}
+
+func (institute *InstituteAccessServer) GetTemplateAuth() (func(string) string) {
+ return func(authURL string) string {
+ return authURL
+ }
+}
+
+func (secure *SecureInternetHomeServer) GetTemplateAuth() (func(string) string) {
+ return func(authURL string) string {
+ return util.ReplaceWAYF(secure.AuthorizationTemplate, authURL, secure.HomeOrganizationID)
+ }
+}
+
func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) {
return &institute.Base, nil
}
@@ -115,10 +130,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}}
}
- base, exists := server.BaseMap[server.CurrentURL]
+ base, exists := server.BaseMap[server.CurrentLocation]
if !exists {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}}
}
return base, nil
}
@@ -137,23 +152,27 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l
return nil
}
-func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error {
- errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url)
+func (servers *Servers) HasSecureLocation() bool {
+ return servers.SecureInternetHomeServer.CurrentLocation != ""
+}
+
+func (secure *SecureInternetHomeServer) addLocation(locationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (*ServerBase, error) {
+ errorMessage := "failed adding a location"
// Initialize the base map if it is non-nil
if secure.BaseMap == nil {
secure.BaseMap = make(map[string]*ServerBase)
}
- // Add it if not present
- base, exists := secure.BaseMap[url]
+ // Add the location to the base map
+ base, exists := secure.BaseMap[locationServer.CountryCode]
if !exists || base == nil {
// Create the base to be added to the map
base = &ServerBase{}
- base.URL = url
- endpoints, endpointsErr := APIGetEndpoints(url)
+ base.URL = locationServer.BaseURL
+ endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL)
if endpointsErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
}
base.Endpoints = *endpoints
}
@@ -163,19 +182,34 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l
base.Logger = logger
// Ensure it is in the map
- secure.BaseMap[url] = base
+ secure.BaseMap[locationServer.CountryCode] = base
+ return base, nil
+}
+
- // Set the home url if it is not set yet
- if secure.HomeURL == "" {
- secure.HomeURL = url
- // Make sure oauth contains our endpoints
- secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger)
- } else { // Else just pass in the fsm and logger
- secure.OAuth.Update(fsm, logger)
+// Initializes the home server and adds its own location
+func (secure *SecureInternetHomeServer) init(homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error {
+ errorMessage := "failed initializing secure internet home server"
+
+ if secure.HomeOrganizationID != homeOrg.OrgId {
+ // New home organisation, clear everything
+ *secure = *&SecureInternetHomeServer{}
+ }
+
+
+ base, baseErr := secure.addLocation(homeLocation, fsm, logger)
+
+ if baseErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
- // Set the current url
- secure.CurrentURL = url
+ // Make sure to set the organization ID
+ secure.HomeOrganizationID = homeOrg.OrgId
+
+ // Make sure to set the authorization URL template
+ secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate
+ // Make sure oauth contains our endpoints
+ secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger)
return nil
}
@@ -211,7 +245,7 @@ func ShouldRenewButton(server Server) (bool, error) {
}
func Login(server Server) error {
- return server.GetOAuth().Login("org.eduvpn.app.linux")
+ return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth())
}
func EnsureTokens(server Server) error {
@@ -240,21 +274,9 @@ func CancelOAuth(server Server) {
server.GetOAuth().Cancel()
}
-func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
- errorMessage := "failed ensuring server"
- // Intialize the secure internet server
- // This calls the init method which takes care of the rest
- if isSecureInternet {
- initErr := servers.SecureInternetHomeServer.init(url, fsm, logger)
-
- if initErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
- }
-
- servers.IsSecureInternet = true
- return &servers.SecureInternetHomeServer, nil
- }
-
+func (servers *Servers) AddInstituteAccess(instituteServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
+ url := instituteServer.BaseURL
+ errorMessage := fmt.Sprintf("failed adding institute access server: %s", url)
instituteServers := &servers.InstituteServers
if instituteServers.Map == nil {
@@ -279,6 +301,33 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm
return institute, nil
}
+func (servers *Servers) SetSecureLocation(chosenLocationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error {
+ errorMessage := "failed to set secure location"
+ // Make sure to add the current location
+ _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm, logger)
+
+ if addLocationErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr}
+ }
+
+ servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
+ return nil
+}
+
+func (servers *Servers) AddSecureInternet(secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
+ errorMessage := "failed adding secure internet server"
+ // If we have specified an organization ID
+ // We also need to get an authorization template
+ initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm, logger)
+
+ if initErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
+ }
+
+ servers.IsSecureInternet = true
+ return &servers.SecureInternetHomeServer, nil
+}
+
type ServerProfile struct {
ID string `json:"profile_id"`
DisplayName string `json:"display_name"`
@@ -554,5 +603,5 @@ type ServerSecureInternetBaseNotFoundError struct {
}
func (e *ServerSecureInternetBaseNotFoundError) Error() string {
- return fmt.Sprintf("secure internet base not found with current: %s", e.Current)
+ return fmt.Sprintf("secure internet base not found with current location: %s", e.Current)
}
diff --git a/internal/types/server.go b/internal/types/server.go
new file mode 100644
index 0000000..ba9b217
--- /dev/null
+++ b/internal/types/server.go
@@ -0,0 +1,42 @@
+package types
+
+// Shared server types
+
+// Structs that define the json format for
+// url: "https://disco.eduvpn.org/v2/organization_list.json"
+type DiscoveryOrganizations struct {
+ Version uint64 `json:"v"`
+ List []DiscoveryOrganization `json:"organization_list"`
+ Timestamp int64 `json:"-"`
+ RawString string `json:"-"`
+}
+
+type DiscoveryOrganization struct {
+ DisplayName struct {
+ En string `json:"en"`
+ } `json:"display_name"`
+ OrgId string `json:"org_id"`
+ SecureInternetHome string `json:"secure_internet_home"`
+ KeywordList struct {
+ En string `json:"en"`
+ } `json:"keyword_list"`
+}
+
+// Structs that define the json format for
+// url: "https://disco.eduvpn.org/v2/server_list.json"
+type DiscoveryServers struct {
+ Version uint64 `json:"v"`
+ List []DiscoveryServer `json:"server_list"`
+ Timestamp int64 `json:"-"`
+ RawString string `json:"-"`
+}
+
+type DiscoveryServer struct {
+ AuthenticationURLTemplate string `json:"authentication_url_template"`
+ BaseURL string `json:"base_url"`
+ CountryCode string `json:"country_code"`
+ PublicKeyList []string `json:"public_key_list"`
+ Type string `json:"server_type"`
+ SupportContact []string `json:"support_contact"`
+}
+
diff --git a/internal/util/util.go b/internal/util/util.go
index 8dee61e..30767c3 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -4,8 +4,11 @@ import (
"crypto/rand"
"fmt"
"os"
+ "strings"
"time"
+ "net/url"
+
"github.com/jwijenbergh/eduvpn-common/internal/types"
)
@@ -31,3 +34,23 @@ func EnsureDirectory(directory string) error {
}
return nil
}
+
+// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
+// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape
+func WAYFEncode(input string) string {
+ // QueryReplace already replaces a space with a +
+ // see https://go.dev/play/p/pOfrn-Wsq5
+ return url.QueryEscape(input)
+}
+
+// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
+func ReplaceWAYF(authTemplate string, authURL string, orgID string) string {
+ if authTemplate == "" {
+ return authURL
+ }
+ // Replace authURL
+ authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", WAYFEncode(authURL), 1)
+ // Replace ORG ID
+ authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1)
+ return authTemplate
+}