diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-07-05 13:17:24 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-07-05 13:17:24 +0200 |
| commit | 1865b016d0cca74cd3703db5a3b4217917988dec (patch) | |
| tree | 3da84dbc4f1ad49221c25fb83f402d27deb34138 /internal | |
| parent | e39b9a8a405fa8e5f73c32bb03a3f349f7f9f92d (diff) | |
Refactor: Handling of different servers and identifiers
- Uses OrgID for Secure Internet and gets the data from discovery
- Uses URL for Institute/Custom and gets the data from discovery
- Implements SKIP WAYF as we now have the needed data
- Implements an initial change location with a default location (NL right now)
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/discovery/discovery.go | 145 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 13 | ||||
| -rw-r--r-- | internal/server/server.go | 135 | ||||
| -rw-r--r-- | internal/types/server.go | 42 | ||||
| -rw-r--r-- | internal/util/util.go | 23 |
5 files changed, 274 insertions, 84 deletions
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index ac3bf57..c61469d 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -12,27 +12,15 @@ import ( "github.com/jwijenbergh/eduvpn-common/internal/verify" ) -type OrganizationList struct { - JSON json.RawMessage `json:"organization_list"` - Version uint64 `json:"v"` - Timestamp int64 `json:"-"` -} - -type ServersList struct { - JSON json.RawMessage `json:"server_list"` - Version uint64 `json:"v"` - Timestamp int64 `json:"-"` -} - type Discovery struct { - Organizations OrganizationList - Servers ServersList + Organizations types.DiscoveryOrganizations + Servers types.DiscoveryServers FSM *fsm.FSM Logger *log.FileLogger } // Helper function that gets a disco json -func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error { +func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) (string, error) { errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data discoURL := "https://disco.eduvpn.org/v2/" @@ -40,7 +28,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, fileBody, fileErr := http.HTTPGet(fileURL) if fileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} } // Get signature @@ -49,7 +37,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, sigBody, sigFileErr := http.HTTPGet(sigURL) if sigFileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} } // Verify signature @@ -58,17 +46,17 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash) if !verifySuccess || verifyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} } // Parse JSON to extract version and list jsonErr := json.Unmarshal(fileBody, structure) if jsonErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } - return nil + return string(fileBody), nil } func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { @@ -82,7 +70,63 @@ func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { // - [TODO] when the user tries to add new server AND the user did NOT yet choose an organization before; // - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation. func (discovery *Discovery) DetermineOrganizationsUpdate() bool { - return string(discovery.Organizations.JSON) == "" + return discovery.Organizations.Timestamp == 0 +} + +func (discovery *Discovery) GetSecureLocationList() []string { + var locations []string + for _, server := range discovery.Servers.List { + if server.Type == "secure_internet" { + locations = append(locations, server.CountryCode) + } + } + return locations +} + +func (discovery *Discovery) GetServerByURL(url string, _type string) (*types.DiscoveryServer, error) { + for _, server := range discovery.Servers.List { + if server.BaseURL == url && server.Type == _type { + return &server, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting server by URL from discovery", Err: &GetServerByURLNotFoundError{URL: url, Type: _type}} +} + +func (discovery *Discovery) GetServerByCountryCode(code string, _type string) (*types.DiscoveryServer, error) { + for _, server := range discovery.Servers.List { + if server.CountryCode == code && server.Type == _type { + return &server, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting server by country code from discovery", Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}} +} + +func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) { + for _, organization := range discovery.Organizations.List { + if organization.OrgId == orgID { + return &organization, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting Secure Internet Home URL from discovery", Err: &GetOrgByIDNotFoundError{ID: orgID}} +} + +func (discovery *Discovery) GetSecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) { + errorMessage := "failed getting Secure Internet Home arguments from discovery" + org, orgErr := discovery.getOrgByID(orgID) + + if orgErr != nil { + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr} + } + + // Get a server with the base url + url := org.SecureInternetHome + + server, serverErr := discovery.GetServerByURL(url, "secure_internet") + + if serverErr != nil { + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + return org, server, nil } // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md @@ -90,7 +134,7 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool { // - The application MAY refresh the server_list.json periodically, e.g. once every hour. func (discovery *Discovery) DetermineServersUpdate() bool { // No servers, we should update - if string(discovery.Servers.JSON) == "" { + if discovery.Servers.Timestamp == 0 { return true } // 1 hour from the last update @@ -106,29 +150,66 @@ func (discovery *Discovery) DetermineServersUpdate() bool { // Get the organization list func (discovery *Discovery) GetOrganizationsList() (string, error) { if !discovery.DetermineOrganizationsUpdate() { - return string(discovery.Organizations.JSON), nil + return discovery.Organizations.RawString, nil } file := "organization_list.json" - err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) - if err != nil { + body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) + if bodyErr != nil { // Return previous with an error - return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err} + return discovery.Organizations.RawString, &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: bodyErr} } - return string(discovery.Organizations.JSON), nil + discovery.Organizations.RawString = body + discovery.Organizations.Timestamp = util.GenerateTimeSeconds() + return discovery.Organizations.RawString, nil } // Get the server list func (discovery *Discovery) GetServersList() (string, error) { if !discovery.DetermineServersUpdate() { - return string(discovery.Servers.JSON), nil + return discovery.Servers.RawString, nil } file := "server_list.json" - err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) - if err != nil { + body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) + if bodyErr != nil { // Return previous with an error - return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err} + return discovery.Servers.RawString, &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: bodyErr} } // Update servers timestamp + discovery.Servers.RawString = body discovery.Servers.Timestamp = util.GenerateTimeSeconds() - return string(discovery.Servers.JSON), nil + return discovery.Servers.RawString, nil +} + +type GetOrgByIDNotFoundError struct { + ID string +} + +func (e GetOrgByIDNotFoundError) Error() string { + return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s", e.ID) +} + +type GetServerByURLNotFoundError struct { + URL string + Type string +} + +func (e GetServerByURLNotFoundError) Error() string { + return fmt.Sprintf("No institute access server found in organizations with URL %s and type %s", e.URL, e.Type) +} + +type GetServerByCountryCodeNotFoundError struct { + CountryCode string + Type string +} + +func (e GetServerByCountryCodeNotFoundError) Error() string { + return fmt.Sprintf("No institute access server found in organizations with country code %s and type %s", e.CountryCode, e.Type) +} + +type GetSecureHomeArgsNotFoundError struct { + URL string +} + +func (e GetSecureHomeArgsNotFoundError) Error() string { + return fmt.Sprintf("No Secure Internet Home found with URL: %s", e.URL) } diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 824db90..ef1bed4 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -223,11 +223,6 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { } } -func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) { - oauth.FSM = fsm - oauth.Logger = logger -} - func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM, logger *log.FileLogger) { oauth.BaseAuthorizationURL = baseAuthorizationURL oauth.TokenURL = tokenURL @@ -236,7 +231,7 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm. } // Starts the OAuth exchange for eduvpn. -func (oauth *OAuth) start(name string) error { +func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error { errorMessage := "failed starting OAuth exchange" if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} @@ -274,7 +269,7 @@ func (oauth *OAuth) start(name string) error { oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} oauth.Session = oauthSession // Run the state callback in the background so that the user can login while we start the callback server - oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true) + oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true) return nil } @@ -298,9 +293,9 @@ func (oauth *OAuth) Cancel() { oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Login(name string) error { +func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error { errorMessage := "failed OAuth login" - authInitializeErr := oauth.start(name) + authInitializeErr := oauth.start(name, postprocessAuth) if authInitializeErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} diff --git a/internal/server/server.go b/internal/server/server.go index 807bd09..5bc2ea1 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -36,14 +36,16 @@ type InstituteAccessServer struct { // A secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to type SecureInternetHomeServer struct { + DisplayName string `json:"display_name"` OAuth oauth.OAuth `json:"oauth"` - // The home server has a list of info for each configured server + // The home server has a list of info for each configured server location BaseMap map[string]*ServerBase `json:"base_map"` - // We have the home url and the current url - HomeURL string `json:"home_url"` - CurrentURL string `json:"current_url"` + // We have the authorization URL template, the home organization ID and the current location + AuthorizationTemplate string `json:"authorization_template"` + HomeOrganizationID string `json:"home_organization_id"` + CurrentLocation string `json:"current_location"` } type InstituteServers struct { @@ -89,11 +91,11 @@ type Server interface { // Gets the current OAuth object GetOAuth() *oauth.OAuth + // Get the authorization URL template function + GetTemplateAuth() func(string) string + // Gets the server base GetBase() (*ServerBase, error) - - // initialize method - init(url string, fsm *fsm.FSM, logger *log.FileLogger) error } // For an institute, we can simply get the OAuth @@ -105,6 +107,19 @@ func (secure *SecureInternetHomeServer) GetOAuth() *oauth.OAuth { return &secure.OAuth } + +func (institute *InstituteAccessServer) GetTemplateAuth() (func(string) string) { + return func(authURL string) string { + return authURL + } +} + +func (secure *SecureInternetHomeServer) GetTemplateAuth() (func(string) string) { + return func(authURL string) string { + return util.ReplaceWAYF(secure.AuthorizationTemplate, authURL, secure.HomeOrganizationID) + } +} + func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { return &institute.Base, nil } @@ -115,10 +130,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}} } - base, exists := server.BaseMap[server.CurrentURL] + base, exists := server.BaseMap[server.CurrentLocation] if !exists { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}} } return base, nil } @@ -137,23 +152,27 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l return nil } -func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { - errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url) +func (servers *Servers) HasSecureLocation() bool { + return servers.SecureInternetHomeServer.CurrentLocation != "" +} + +func (secure *SecureInternetHomeServer) addLocation(locationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (*ServerBase, error) { + errorMessage := "failed adding a location" // Initialize the base map if it is non-nil if secure.BaseMap == nil { secure.BaseMap = make(map[string]*ServerBase) } - // Add it if not present - base, exists := secure.BaseMap[url] + // Add the location to the base map + base, exists := secure.BaseMap[locationServer.CountryCode] if !exists || base == nil { // Create the base to be added to the map base = &ServerBase{} - base.URL = url - endpoints, endpointsErr := APIGetEndpoints(url) + base.URL = locationServer.BaseURL + endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL) if endpointsErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } base.Endpoints = *endpoints } @@ -163,19 +182,34 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l base.Logger = logger // Ensure it is in the map - secure.BaseMap[url] = base + secure.BaseMap[locationServer.CountryCode] = base + return base, nil +} + - // Set the home url if it is not set yet - if secure.HomeURL == "" { - secure.HomeURL = url - // Make sure oauth contains our endpoints - secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger) - } else { // Else just pass in the fsm and logger - secure.OAuth.Update(fsm, logger) +// Initializes the home server and adds its own location +func (secure *SecureInternetHomeServer) init(homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := "failed initializing secure internet home server" + + if secure.HomeOrganizationID != homeOrg.OrgId { + // New home organisation, clear everything + *secure = *&SecureInternetHomeServer{} + } + + + base, baseErr := secure.addLocation(homeLocation, fsm, logger) + + if baseErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } - // Set the current url - secure.CurrentURL = url + // Make sure to set the organization ID + secure.HomeOrganizationID = homeOrg.OrgId + + // Make sure to set the authorization URL template + secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate + // Make sure oauth contains our endpoints + secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger) return nil } @@ -211,7 +245,7 @@ func ShouldRenewButton(server Server) (bool, error) { } func Login(server Server) error { - return server.GetOAuth().Login("org.eduvpn.app.linux") + return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth()) } func EnsureTokens(server Server) error { @@ -240,21 +274,9 @@ func CancelOAuth(server Server) { server.GetOAuth().Cancel() } -func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { - errorMessage := "failed ensuring server" - // Intialize the secure internet server - // This calls the init method which takes care of the rest - if isSecureInternet { - initErr := servers.SecureInternetHomeServer.init(url, fsm, logger) - - if initErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} - } - - servers.IsSecureInternet = true - return &servers.SecureInternetHomeServer, nil - } - +func (servers *Servers) AddInstituteAccess(instituteServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { + url := instituteServer.BaseURL + errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) instituteServers := &servers.InstituteServers if instituteServers.Map == nil { @@ -279,6 +301,33 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm return institute, nil } +func (servers *Servers) SetSecureLocation(chosenLocationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := "failed to set secure location" + // Make sure to add the current location + _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm, logger) + + if addLocationErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr} + } + + servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode + return nil +} + +func (servers *Servers) AddSecureInternet(secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { + errorMessage := "failed adding secure internet server" + // If we have specified an organization ID + // We also need to get an authorization template + initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm, logger) + + if initErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} + } + + servers.IsSecureInternet = true + return &servers.SecureInternetHomeServer, nil +} + type ServerProfile struct { ID string `json:"profile_id"` DisplayName string `json:"display_name"` @@ -554,5 +603,5 @@ type ServerSecureInternetBaseNotFoundError struct { } func (e *ServerSecureInternetBaseNotFoundError) Error() string { - return fmt.Sprintf("secure internet base not found with current: %s", e.Current) + return fmt.Sprintf("secure internet base not found with current location: %s", e.Current) } diff --git a/internal/types/server.go b/internal/types/server.go new file mode 100644 index 0000000..ba9b217 --- /dev/null +++ b/internal/types/server.go @@ -0,0 +1,42 @@ +package types + +// Shared server types + +// Structs that define the json format for +// url: "https://disco.eduvpn.org/v2/organization_list.json" +type DiscoveryOrganizations struct { + Version uint64 `json:"v"` + List []DiscoveryOrganization `json:"organization_list"` + Timestamp int64 `json:"-"` + RawString string `json:"-"` +} + +type DiscoveryOrganization struct { + DisplayName struct { + En string `json:"en"` + } `json:"display_name"` + OrgId string `json:"org_id"` + SecureInternetHome string `json:"secure_internet_home"` + KeywordList struct { + En string `json:"en"` + } `json:"keyword_list"` +} + +// Structs that define the json format for +// url: "https://disco.eduvpn.org/v2/server_list.json" +type DiscoveryServers struct { + Version uint64 `json:"v"` + List []DiscoveryServer `json:"server_list"` + Timestamp int64 `json:"-"` + RawString string `json:"-"` +} + +type DiscoveryServer struct { + AuthenticationURLTemplate string `json:"authentication_url_template"` + BaseURL string `json:"base_url"` + CountryCode string `json:"country_code"` + PublicKeyList []string `json:"public_key_list"` + Type string `json:"server_type"` + SupportContact []string `json:"support_contact"` +} + diff --git a/internal/util/util.go b/internal/util/util.go index 8dee61e..30767c3 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -4,8 +4,11 @@ import ( "crypto/rand" "fmt" "os" + "strings" "time" + "net/url" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) @@ -31,3 +34,23 @@ func EnsureDirectory(directory string) error { } return nil } + +// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md +// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape +func WAYFEncode(input string) string { + // QueryReplace already replaces a space with a + + // see https://go.dev/play/p/pOfrn-Wsq5 + return url.QueryEscape(input) +} + +// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md +func ReplaceWAYF(authTemplate string, authURL string, orgID string) string { + if authTemplate == "" { + return authURL + } + // Replace authURL + authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", WAYFEncode(authURL), 1) + // Replace ORG ID + authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1) + return authTemplate +} |
