diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-07-05 13:17:24 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-07-05 13:17:24 +0200 |
| commit | 1865b016d0cca74cd3703db5a3b4217917988dec (patch) | |
| tree | 3da84dbc4f1ad49221c25fb83f402d27deb34138 /internal/discovery | |
| parent | e39b9a8a405fa8e5f73c32bb03a3f349f7f9f92d (diff) | |
Refactor: Handling of different servers and identifiers
- Uses OrgID for Secure Internet and gets the data from discovery
- Uses URL for Institute/Custom and gets the data from discovery
- Implements SKIP WAYF as we now have the needed data
- Implements an initial change location with a default location (NL right now)
Diffstat (limited to 'internal/discovery')
| -rw-r--r-- | internal/discovery/discovery.go | 145 |
1 files changed, 113 insertions, 32 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index ac3bf57..c61469d 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -12,27 +12,15 @@ import ( "github.com/jwijenbergh/eduvpn-common/internal/verify" ) -type OrganizationList struct { - JSON json.RawMessage `json:"organization_list"` - Version uint64 `json:"v"` - Timestamp int64 `json:"-"` -} - -type ServersList struct { - JSON json.RawMessage `json:"server_list"` - Version uint64 `json:"v"` - Timestamp int64 `json:"-"` -} - type Discovery struct { - Organizations OrganizationList - Servers ServersList + Organizations types.DiscoveryOrganizations + Servers types.DiscoveryServers FSM *fsm.FSM Logger *log.FileLogger } // Helper function that gets a disco json -func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error { +func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) (string, error) { errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data discoURL := "https://disco.eduvpn.org/v2/" @@ -40,7 +28,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, fileBody, fileErr := http.HTTPGet(fileURL) if fileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} } // Get signature @@ -49,7 +37,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, sigBody, sigFileErr := http.HTTPGet(sigURL) if sigFileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} } // Verify signature @@ -58,17 +46,17 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash) if !verifySuccess || verifyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} } // Parse JSON to extract version and list jsonErr := json.Unmarshal(fileBody, structure) if jsonErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } - return nil + return string(fileBody), nil } func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { @@ -82,7 +70,63 @@ func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { // - [TODO] when the user tries to add new server AND the user did NOT yet choose an organization before; // - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation. func (discovery *Discovery) DetermineOrganizationsUpdate() bool { - return string(discovery.Organizations.JSON) == "" + return discovery.Organizations.Timestamp == 0 +} + +func (discovery *Discovery) GetSecureLocationList() []string { + var locations []string + for _, server := range discovery.Servers.List { + if server.Type == "secure_internet" { + locations = append(locations, server.CountryCode) + } + } + return locations +} + +func (discovery *Discovery) GetServerByURL(url string, _type string) (*types.DiscoveryServer, error) { + for _, server := range discovery.Servers.List { + if server.BaseURL == url && server.Type == _type { + return &server, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting server by URL from discovery", Err: &GetServerByURLNotFoundError{URL: url, Type: _type}} +} + +func (discovery *Discovery) GetServerByCountryCode(code string, _type string) (*types.DiscoveryServer, error) { + for _, server := range discovery.Servers.List { + if server.CountryCode == code && server.Type == _type { + return &server, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting server by country code from discovery", Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}} +} + +func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) { + for _, organization := range discovery.Organizations.List { + if organization.OrgId == orgID { + return &organization, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting Secure Internet Home URL from discovery", Err: &GetOrgByIDNotFoundError{ID: orgID}} +} + +func (discovery *Discovery) GetSecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) { + errorMessage := "failed getting Secure Internet Home arguments from discovery" + org, orgErr := discovery.getOrgByID(orgID) + + if orgErr != nil { + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr} + } + + // Get a server with the base url + url := org.SecureInternetHome + + server, serverErr := discovery.GetServerByURL(url, "secure_internet") + + if serverErr != nil { + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + return org, server, nil } // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md @@ -90,7 +134,7 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool { // - The application MAY refresh the server_list.json periodically, e.g. once every hour. func (discovery *Discovery) DetermineServersUpdate() bool { // No servers, we should update - if string(discovery.Servers.JSON) == "" { + if discovery.Servers.Timestamp == 0 { return true } // 1 hour from the last update @@ -106,29 +150,66 @@ func (discovery *Discovery) DetermineServersUpdate() bool { // Get the organization list func (discovery *Discovery) GetOrganizationsList() (string, error) { if !discovery.DetermineOrganizationsUpdate() { - return string(discovery.Organizations.JSON), nil + return discovery.Organizations.RawString, nil } file := "organization_list.json" - err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) - if err != nil { + body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) + if bodyErr != nil { // Return previous with an error - return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err} + return discovery.Organizations.RawString, &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: bodyErr} } - return string(discovery.Organizations.JSON), nil + discovery.Organizations.RawString = body + discovery.Organizations.Timestamp = util.GenerateTimeSeconds() + return discovery.Organizations.RawString, nil } // Get the server list func (discovery *Discovery) GetServersList() (string, error) { if !discovery.DetermineServersUpdate() { - return string(discovery.Servers.JSON), nil + return discovery.Servers.RawString, nil } file := "server_list.json" - err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) - if err != nil { + body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) + if bodyErr != nil { // Return previous with an error - return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err} + return discovery.Servers.RawString, &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: bodyErr} } // Update servers timestamp + discovery.Servers.RawString = body discovery.Servers.Timestamp = util.GenerateTimeSeconds() - return string(discovery.Servers.JSON), nil + return discovery.Servers.RawString, nil +} + +type GetOrgByIDNotFoundError struct { + ID string +} + +func (e GetOrgByIDNotFoundError) Error() string { + return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s", e.ID) +} + +type GetServerByURLNotFoundError struct { + URL string + Type string +} + +func (e GetServerByURLNotFoundError) Error() string { + return fmt.Sprintf("No institute access server found in organizations with URL %s and type %s", e.URL, e.Type) +} + +type GetServerByCountryCodeNotFoundError struct { + CountryCode string + Type string +} + +func (e GetServerByCountryCodeNotFoundError) Error() string { + return fmt.Sprintf("No institute access server found in organizations with country code %s and type %s", e.CountryCode, e.Type) +} + +type GetSecureHomeArgsNotFoundError struct { + URL string +} + +func (e GetSecureHomeArgsNotFoundError) Error() string { + return fmt.Sprintf("No Secure Internet Home found with URL: %s", e.URL) } |
