summaryrefslogtreecommitdiff
path: root/internal/discovery
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-07-05 13:17:24 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-07-05 13:17:24 +0200
commit1865b016d0cca74cd3703db5a3b4217917988dec (patch)
tree3da84dbc4f1ad49221c25fb83f402d27deb34138 /internal/discovery
parente39b9a8a405fa8e5f73c32bb03a3f349f7f9f92d (diff)
Refactor: Handling of different servers and identifiers
- Uses OrgID for Secure Internet and gets the data from discovery - Uses URL for Institute/Custom and gets the data from discovery - Implements SKIP WAYF as we now have the needed data - Implements an initial change location with a default location (NL right now)
Diffstat (limited to 'internal/discovery')
-rw-r--r--internal/discovery/discovery.go145
1 files changed, 113 insertions, 32 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index ac3bf57..c61469d 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -12,27 +12,15 @@ import (
"github.com/jwijenbergh/eduvpn-common/internal/verify"
)
-type OrganizationList struct {
- JSON json.RawMessage `json:"organization_list"`
- Version uint64 `json:"v"`
- Timestamp int64 `json:"-"`
-}
-
-type ServersList struct {
- JSON json.RawMessage `json:"server_list"`
- Version uint64 `json:"v"`
- Timestamp int64 `json:"-"`
-}
-
type Discovery struct {
- Organizations OrganizationList
- Servers ServersList
+ Organizations types.DiscoveryOrganizations
+ Servers types.DiscoveryServers
FSM *fsm.FSM
Logger *log.FileLogger
}
// Helper function that gets a disco json
-func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
+func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) (string, error) {
errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile)
// Get json data
discoURL := "https://disco.eduvpn.org/v2/"
@@ -40,7 +28,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, fileBody, fileErr := http.HTTPGet(fileURL)
if fileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
}
// Get signature
@@ -49,7 +37,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, sigBody, sigFileErr := http.HTTPGet(sigURL)
if sigFileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
}
// Verify signature
@@ -58,17 +46,17 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash)
if !verifySuccess || verifyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
}
// Parse JSON to extract version and list
jsonErr := json.Unmarshal(fileBody, structure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
}
- return nil
+ return string(fileBody), nil
}
func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
@@ -82,7 +70,63 @@ func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
// - [TODO] when the user tries to add new server AND the user did NOT yet choose an organization before;
// - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation.
func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
- return string(discovery.Organizations.JSON) == ""
+ return discovery.Organizations.Timestamp == 0
+}
+
+func (discovery *Discovery) GetSecureLocationList() []string {
+ var locations []string
+ for _, server := range discovery.Servers.List {
+ if server.Type == "secure_internet" {
+ locations = append(locations, server.CountryCode)
+ }
+ }
+ return locations
+}
+
+func (discovery *Discovery) GetServerByURL(url string, _type string) (*types.DiscoveryServer, error) {
+ for _, server := range discovery.Servers.List {
+ if server.BaseURL == url && server.Type == _type {
+ return &server, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting server by URL from discovery", Err: &GetServerByURLNotFoundError{URL: url, Type: _type}}
+}
+
+func (discovery *Discovery) GetServerByCountryCode(code string, _type string) (*types.DiscoveryServer, error) {
+ for _, server := range discovery.Servers.List {
+ if server.CountryCode == code && server.Type == _type {
+ return &server, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting server by country code from discovery", Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}}
+}
+
+func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) {
+ for _, organization := range discovery.Organizations.List {
+ if organization.OrgId == orgID {
+ return &organization, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting Secure Internet Home URL from discovery", Err: &GetOrgByIDNotFoundError{ID: orgID}}
+}
+
+func (discovery *Discovery) GetSecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) {
+ errorMessage := "failed getting Secure Internet Home arguments from discovery"
+ org, orgErr := discovery.getOrgByID(orgID)
+
+ if orgErr != nil {
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr}
+ }
+
+ // Get a server with the base url
+ url := org.SecureInternetHome
+
+ server, serverErr := discovery.GetServerByURL(url, "secure_internet")
+
+ if serverErr != nil {
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+ return org, server, nil
}
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
@@ -90,7 +134,7 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
// - The application MAY refresh the server_list.json periodically, e.g. once every hour.
func (discovery *Discovery) DetermineServersUpdate() bool {
// No servers, we should update
- if string(discovery.Servers.JSON) == "" {
+ if discovery.Servers.Timestamp == 0 {
return true
}
// 1 hour from the last update
@@ -106,29 +150,66 @@ func (discovery *Discovery) DetermineServersUpdate() bool {
// Get the organization list
func (discovery *Discovery) GetOrganizationsList() (string, error) {
if !discovery.DetermineOrganizationsUpdate() {
- return string(discovery.Organizations.JSON), nil
+ return discovery.Organizations.RawString, nil
}
file := "organization_list.json"
- err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
- if err != nil {
+ body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
+ if bodyErr != nil {
// Return previous with an error
- return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err}
+ return discovery.Organizations.RawString, &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: bodyErr}
}
- return string(discovery.Organizations.JSON), nil
+ discovery.Organizations.RawString = body
+ discovery.Organizations.Timestamp = util.GenerateTimeSeconds()
+ return discovery.Organizations.RawString, nil
}
// Get the server list
func (discovery *Discovery) GetServersList() (string, error) {
if !discovery.DetermineServersUpdate() {
- return string(discovery.Servers.JSON), nil
+ return discovery.Servers.RawString, nil
}
file := "server_list.json"
- err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
- if err != nil {
+ body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
+ if bodyErr != nil {
// Return previous with an error
- return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err}
+ return discovery.Servers.RawString, &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: bodyErr}
}
// Update servers timestamp
+ discovery.Servers.RawString = body
discovery.Servers.Timestamp = util.GenerateTimeSeconds()
- return string(discovery.Servers.JSON), nil
+ return discovery.Servers.RawString, nil
+}
+
+type GetOrgByIDNotFoundError struct {
+ ID string
+}
+
+func (e GetOrgByIDNotFoundError) Error() string {
+ return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s", e.ID)
+}
+
+type GetServerByURLNotFoundError struct {
+ URL string
+ Type string
+}
+
+func (e GetServerByURLNotFoundError) Error() string {
+ return fmt.Sprintf("No institute access server found in organizations with URL %s and type %s", e.URL, e.Type)
+}
+
+type GetServerByCountryCodeNotFoundError struct {
+ CountryCode string
+ Type string
+}
+
+func (e GetServerByCountryCodeNotFoundError) Error() string {
+ return fmt.Sprintf("No institute access server found in organizations with country code %s and type %s", e.CountryCode, e.Type)
+}
+
+type GetSecureHomeArgsNotFoundError struct {
+ URL string
+}
+
+func (e GetSecureHomeArgsNotFoundError) Error() string {
+ return fmt.Sprintf("No Secure Internet Home found with URL: %s", e.URL)
}