diff options
| -rw-r--r-- | cmd/cli/main.go | 35 | ||||
| -rw-r--r-- | exports/exports.go | 35 | ||||
| -rw-r--r-- | internal/discovery/discovery.go | 145 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 13 | ||||
| -rw-r--r-- | internal/server/server.go | 135 | ||||
| -rw-r--r-- | internal/types/server.go | 42 | ||||
| -rw-r--r-- | internal/util/util.go | 23 | ||||
| -rw-r--r-- | state.go | 171 | ||||
| -rw-r--r-- | state_test.go | 14 | ||||
| -rw-r--r-- | wrappers/python/main.py | 4 | ||||
| -rw-r--r-- | wrappers/python/src/__init__.py | 11 | ||||
| -rw-r--r-- | wrappers/python/src/main.py | 15 | ||||
| -rw-r--r-- | wrappers/python/tests.py | 2 |
13 files changed, 482 insertions, 163 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go index e6be0bf..93b81e0 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -13,6 +13,13 @@ import ( eduvpn "github.com/jwijenbergh/eduvpn-common" ) +type ServerTypes int8 +const ( + ServerTypeInstituteAccess ServerTypes = iota + ServerTypeSecureInternet + ServerTypeCustom +) + // Open a browser with xdg-open func openBrowser(urlString string) { fmt.Printf("OAuth: Initialized with AuthURL %s\n", urlString) @@ -89,16 +96,15 @@ func stateCallback(state *eduvpn.VPNState, oldState string, newState string, dat } // Get a config for Institute Access or Secure Internet Server -func getConfig(state *eduvpn.VPNState, url string, isInstitute bool) (string, string, error) { - if !strings.HasPrefix(url, "https://") { +func getConfig(state *eduvpn.VPNState, url string, serverType ServerTypes) (string, string, error) { + if !strings.HasPrefix(url, "http") { url = "https://" + url } - if !strings.HasSuffix(url, "/") { - url += "/" - } // Force TCP is set to False - if isInstitute { + if serverType == ServerTypeInstituteAccess { return state.GetConfigInstituteAccess(url, false) + } else if serverType == ServerTypeCustom { + return state.GetConfigCustomServer(url, false) } return state.GetConfigSecureInternet(url, false) } @@ -136,7 +142,7 @@ func storeSecureInternetConfig(state *eduvpn.VPNState, url string, directory str fmt.Println("Creating and storing cert for", url) - config, _, configErr := getConfig(state, url, false) + config, _, configErr := getConfig(state, url, ServerTypeSecureInternet) if configErr != nil { fmt.Printf("Failed obtaining config for url %s with error %v\n", url, configErr) @@ -193,7 +199,7 @@ func getSecureInternetAll(homeURL string) { } // Get a config for a single server, Institute Access or Secure Internet -func printConfig(url string, isInstitute bool) { +func printConfig(url string, serverType ServerTypes) { state := &eduvpn.VPNState{} state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) { @@ -202,7 +208,7 @@ func printConfig(url string, isInstitute bool) { defer state.Deregister() - config, _, configErr := getConfig(state, url, isInstitute) + config, _, configErr := getConfig(state, url, serverType) if configErr != nil { // Show the usage of tracebacks and causes @@ -217,20 +223,25 @@ func printConfig(url string, isInstitute bool) { // The main function // It parses the arguments and executes the correct functions func main() { + customUrlArg := flag.String("get-custom", "", "The url of a custom server to connect to") urlArg := flag.String("get-institute", "", "The url of an institute to connect to") secureInternet := flag.String("get-secure", "", "Gets secure internet servers.") secureInternetAll := flag.String("get-secure-all", "", "Gets certificates for all secure internet servers. It stores them in ./certs. Provide an URL for the home server e.g. nl.eduvpn.org.") flag.Parse() // Connect to a VPN by getting an Institute Access config + customUrlString := *customUrlArg urlString := *urlArg secureInternetString := *secureInternet secureInternetAllString := *secureInternetAll - if urlString != "" { - printConfig(urlString, true) + if customUrlString != "" { + printConfig(customUrlString, ServerTypeCustom) + return + } else if urlString != "" { + printConfig(urlString, ServerTypeInstituteAccess) return } else if secureInternetString != "" { - printConfig(secureInternetString, false) + printConfig(secureInternetString, ServerTypeSecureInternet) return } else if secureInternetAllString != "" { getSecureInternetAll(secureInternetAllString) diff --git a/exports/exports.go b/exports/exports.go index 567e189..6b33da6 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -107,22 +107,39 @@ func CancelOAuth(name *C.char) *C.char { return C.CString(cancelErrString) } -//export GetConnectConfig -func GetConnectConfig(name *C.char, url *C.char, isSecureInternet C.int, forceTCP C.int) (*C.char, *C.char, *C.char) { +//export GetConfigSecureInternet +func GetConfigSecureInternet(name *C.char, orgID *C.char, forceTCP C.int) (*C.char, *C.char, *C.char) { nameStr := C.GoString(name) state, stateErr := GetVPNState(nameStr) if stateErr != nil { return nil, nil, C.CString(ErrorToString(stateErr)) } - var config string - var configType string - var configErr error forceTCPBool := forceTCP == 1 - if isSecureInternet == 1 { - config, configType, configErr = state.GetConfigSecureInternet(C.GoString(url), forceTCPBool) - } else { - config, configType, configErr = state.GetConfigInstituteAccess(C.GoString(url), forceTCPBool) + config, configType, configErr := state.GetConfigSecureInternet(C.GoString(orgID), forceTCPBool) + return C.CString(config), C.CString(configType), C.CString(ErrorToString(configErr)) +} + +//export GetConfigInstituteAccess +func GetConfigInstituteAccess(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char, *C.char) { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return nil, nil, C.CString(ErrorToString(stateErr)) } + forceTCPBool := forceTCP == 1 + config, configType, configErr := state.GetConfigInstituteAccess(C.GoString(url), forceTCPBool) + return C.CString(config), C.CString(configType), C.CString(ErrorToString(configErr)) +} + +//export GetConfigCustomServer +func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char, *C.char) { + nameStr := C.GoString(name) + state, stateErr := GetVPNState(nameStr) + if stateErr != nil { + return nil, nil, C.CString(ErrorToString(stateErr)) + } + forceTCPBool := forceTCP == 1 + config, configType, configErr := state.GetConfigCustomServer(C.GoString(url), forceTCPBool) return C.CString(config), C.CString(configType), C.CString(ErrorToString(configErr)) } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index ac3bf57..c61469d 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -12,27 +12,15 @@ import ( "github.com/jwijenbergh/eduvpn-common/internal/verify" ) -type OrganizationList struct { - JSON json.RawMessage `json:"organization_list"` - Version uint64 `json:"v"` - Timestamp int64 `json:"-"` -} - -type ServersList struct { - JSON json.RawMessage `json:"server_list"` - Version uint64 `json:"v"` - Timestamp int64 `json:"-"` -} - type Discovery struct { - Organizations OrganizationList - Servers ServersList + Organizations types.DiscoveryOrganizations + Servers types.DiscoveryServers FSM *fsm.FSM Logger *log.FileLogger } // Helper function that gets a disco json -func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error { +func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) (string, error) { errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile) // Get json data discoURL := "https://disco.eduvpn.org/v2/" @@ -40,7 +28,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, fileBody, fileErr := http.HTTPGet(fileURL) if fileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr} } // Get signature @@ -49,7 +37,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} _, sigBody, sigFileErr := http.HTTPGet(sigURL) if sigFileErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr} } // Verify signature @@ -58,17 +46,17 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash) if !verifySuccess || verifyErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr} } // Parse JSON to extract version and list jsonErr := json.Unmarshal(fileBody, structure) if jsonErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr} } - return nil + return string(fileBody), nil } func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { @@ -82,7 +70,63 @@ func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { // - [TODO] when the user tries to add new server AND the user did NOT yet choose an organization before; // - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation. func (discovery *Discovery) DetermineOrganizationsUpdate() bool { - return string(discovery.Organizations.JSON) == "" + return discovery.Organizations.Timestamp == 0 +} + +func (discovery *Discovery) GetSecureLocationList() []string { + var locations []string + for _, server := range discovery.Servers.List { + if server.Type == "secure_internet" { + locations = append(locations, server.CountryCode) + } + } + return locations +} + +func (discovery *Discovery) GetServerByURL(url string, _type string) (*types.DiscoveryServer, error) { + for _, server := range discovery.Servers.List { + if server.BaseURL == url && server.Type == _type { + return &server, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting server by URL from discovery", Err: &GetServerByURLNotFoundError{URL: url, Type: _type}} +} + +func (discovery *Discovery) GetServerByCountryCode(code string, _type string) (*types.DiscoveryServer, error) { + for _, server := range discovery.Servers.List { + if server.CountryCode == code && server.Type == _type { + return &server, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting server by country code from discovery", Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}} +} + +func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) { + for _, organization := range discovery.Organizations.List { + if organization.OrgId == orgID { + return &organization, nil + } + } + return nil, &types.WrappedErrorMessage{Message: "failed getting Secure Internet Home URL from discovery", Err: &GetOrgByIDNotFoundError{ID: orgID}} +} + +func (discovery *Discovery) GetSecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) { + errorMessage := "failed getting Secure Internet Home arguments from discovery" + org, orgErr := discovery.getOrgByID(orgID) + + if orgErr != nil { + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr} + } + + // Get a server with the base url + url := org.SecureInternetHome + + server, serverErr := discovery.GetServerByURL(url, "secure_internet") + + if serverErr != nil { + return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + return org, server, nil } // https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md @@ -90,7 +134,7 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool { // - The application MAY refresh the server_list.json periodically, e.g. once every hour. func (discovery *Discovery) DetermineServersUpdate() bool { // No servers, we should update - if string(discovery.Servers.JSON) == "" { + if discovery.Servers.Timestamp == 0 { return true } // 1 hour from the last update @@ -106,29 +150,66 @@ func (discovery *Discovery) DetermineServersUpdate() bool { // Get the organization list func (discovery *Discovery) GetOrganizationsList() (string, error) { if !discovery.DetermineOrganizationsUpdate() { - return string(discovery.Organizations.JSON), nil + return discovery.Organizations.RawString, nil } file := "organization_list.json" - err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) - if err != nil { + body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations) + if bodyErr != nil { // Return previous with an error - return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err} + return discovery.Organizations.RawString, &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: bodyErr} } - return string(discovery.Organizations.JSON), nil + discovery.Organizations.RawString = body + discovery.Organizations.Timestamp = util.GenerateTimeSeconds() + return discovery.Organizations.RawString, nil } // Get the server list func (discovery *Discovery) GetServersList() (string, error) { if !discovery.DetermineServersUpdate() { - return string(discovery.Servers.JSON), nil + return discovery.Servers.RawString, nil } file := "server_list.json" - err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) - if err != nil { + body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers) + if bodyErr != nil { // Return previous with an error - return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err} + return discovery.Servers.RawString, &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: bodyErr} } // Update servers timestamp + discovery.Servers.RawString = body discovery.Servers.Timestamp = util.GenerateTimeSeconds() - return string(discovery.Servers.JSON), nil + return discovery.Servers.RawString, nil +} + +type GetOrgByIDNotFoundError struct { + ID string +} + +func (e GetOrgByIDNotFoundError) Error() string { + return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s", e.ID) +} + +type GetServerByURLNotFoundError struct { + URL string + Type string +} + +func (e GetServerByURLNotFoundError) Error() string { + return fmt.Sprintf("No institute access server found in organizations with URL %s and type %s", e.URL, e.Type) +} + +type GetServerByCountryCodeNotFoundError struct { + CountryCode string + Type string +} + +func (e GetServerByCountryCodeNotFoundError) Error() string { + return fmt.Sprintf("No institute access server found in organizations with country code %s and type %s", e.CountryCode, e.Type) +} + +type GetSecureHomeArgsNotFoundError struct { + URL string +} + +func (e GetSecureHomeArgsNotFoundError) Error() string { + return fmt.Sprintf("No Secure Internet Home found with URL: %s", e.URL) } diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 824db90..ef1bed4 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -223,11 +223,6 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { } } -func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) { - oauth.FSM = fsm - oauth.Logger = logger -} - func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM, logger *log.FileLogger) { oauth.BaseAuthorizationURL = baseAuthorizationURL oauth.TokenURL = tokenURL @@ -236,7 +231,7 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm. } // Starts the OAuth exchange for eduvpn. -func (oauth *OAuth) start(name string) error { +func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error { errorMessage := "failed starting OAuth exchange" if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} @@ -274,7 +269,7 @@ func (oauth *OAuth) start(name string) error { oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} oauth.Session = oauthSession // Run the state callback in the background so that the user can login while we start the callback server - oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true) + oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true) return nil } @@ -298,9 +293,9 @@ func (oauth *OAuth) Cancel() { oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Login(name string) error { +func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error { errorMessage := "failed OAuth login" - authInitializeErr := oauth.start(name) + authInitializeErr := oauth.start(name, postprocessAuth) if authInitializeErr != nil { return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr} diff --git a/internal/server/server.go b/internal/server/server.go index 807bd09..5bc2ea1 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -36,14 +36,16 @@ type InstituteAccessServer struct { // A secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to type SecureInternetHomeServer struct { + DisplayName string `json:"display_name"` OAuth oauth.OAuth `json:"oauth"` - // The home server has a list of info for each configured server + // The home server has a list of info for each configured server location BaseMap map[string]*ServerBase `json:"base_map"` - // We have the home url and the current url - HomeURL string `json:"home_url"` - CurrentURL string `json:"current_url"` + // We have the authorization URL template, the home organization ID and the current location + AuthorizationTemplate string `json:"authorization_template"` + HomeOrganizationID string `json:"home_organization_id"` + CurrentLocation string `json:"current_location"` } type InstituteServers struct { @@ -89,11 +91,11 @@ type Server interface { // Gets the current OAuth object GetOAuth() *oauth.OAuth + // Get the authorization URL template function + GetTemplateAuth() func(string) string + // Gets the server base GetBase() (*ServerBase, error) - - // initialize method - init(url string, fsm *fsm.FSM, logger *log.FileLogger) error } // For an institute, we can simply get the OAuth @@ -105,6 +107,19 @@ func (secure *SecureInternetHomeServer) GetOAuth() *oauth.OAuth { return &secure.OAuth } + +func (institute *InstituteAccessServer) GetTemplateAuth() (func(string) string) { + return func(authURL string) string { + return authURL + } +} + +func (secure *SecureInternetHomeServer) GetTemplateAuth() (func(string) string) { + return func(authURL string) string { + return util.ReplaceWAYF(secure.AuthorizationTemplate, authURL, secure.HomeOrganizationID) + } +} + func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) { return &institute.Base, nil } @@ -115,10 +130,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}} } - base, exists := server.BaseMap[server.CurrentURL] + base, exists := server.BaseMap[server.CurrentLocation] if !exists { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}} } return base, nil } @@ -137,23 +152,27 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l return nil } -func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { - errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url) +func (servers *Servers) HasSecureLocation() bool { + return servers.SecureInternetHomeServer.CurrentLocation != "" +} + +func (secure *SecureInternetHomeServer) addLocation(locationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (*ServerBase, error) { + errorMessage := "failed adding a location" // Initialize the base map if it is non-nil if secure.BaseMap == nil { secure.BaseMap = make(map[string]*ServerBase) } - // Add it if not present - base, exists := secure.BaseMap[url] + // Add the location to the base map + base, exists := secure.BaseMap[locationServer.CountryCode] if !exists || base == nil { // Create the base to be added to the map base = &ServerBase{} - base.URL = url - endpoints, endpointsErr := APIGetEndpoints(url) + base.URL = locationServer.BaseURL + endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL) if endpointsErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr} } base.Endpoints = *endpoints } @@ -163,19 +182,34 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l base.Logger = logger // Ensure it is in the map - secure.BaseMap[url] = base + secure.BaseMap[locationServer.CountryCode] = base + return base, nil +} + - // Set the home url if it is not set yet - if secure.HomeURL == "" { - secure.HomeURL = url - // Make sure oauth contains our endpoints - secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger) - } else { // Else just pass in the fsm and logger - secure.OAuth.Update(fsm, logger) +// Initializes the home server and adds its own location +func (secure *SecureInternetHomeServer) init(homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := "failed initializing secure internet home server" + + if secure.HomeOrganizationID != homeOrg.OrgId { + // New home organisation, clear everything + *secure = *&SecureInternetHomeServer{} + } + + + base, baseErr := secure.addLocation(homeLocation, fsm, logger) + + if baseErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr} } - // Set the current url - secure.CurrentURL = url + // Make sure to set the organization ID + secure.HomeOrganizationID = homeOrg.OrgId + + // Make sure to set the authorization URL template + secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate + // Make sure oauth contains our endpoints + secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger) return nil } @@ -211,7 +245,7 @@ func ShouldRenewButton(server Server) (bool, error) { } func Login(server Server) error { - return server.GetOAuth().Login("org.eduvpn.app.linux") + return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth()) } func EnsureTokens(server Server) error { @@ -240,21 +274,9 @@ func CancelOAuth(server Server) { server.GetOAuth().Cancel() } -func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { - errorMessage := "failed ensuring server" - // Intialize the secure internet server - // This calls the init method which takes care of the rest - if isSecureInternet { - initErr := servers.SecureInternetHomeServer.init(url, fsm, logger) - - if initErr != nil { - return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} - } - - servers.IsSecureInternet = true - return &servers.SecureInternetHomeServer, nil - } - +func (servers *Servers) AddInstituteAccess(instituteServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { + url := instituteServer.BaseURL + errorMessage := fmt.Sprintf("failed adding institute access server: %s", url) instituteServers := &servers.InstituteServers if instituteServers.Map == nil { @@ -279,6 +301,33 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm return institute, nil } +func (servers *Servers) SetSecureLocation(chosenLocationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error { + errorMessage := "failed to set secure location" + // Make sure to add the current location + _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm, logger) + + if addLocationErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr} + } + + servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode + return nil +} + +func (servers *Servers) AddSecureInternet(secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { + errorMessage := "failed adding secure internet server" + // If we have specified an organization ID + // We also need to get an authorization template + initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm, logger) + + if initErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr} + } + + servers.IsSecureInternet = true + return &servers.SecureInternetHomeServer, nil +} + type ServerProfile struct { ID string `json:"profile_id"` DisplayName string `json:"display_name"` @@ -554,5 +603,5 @@ type ServerSecureInternetBaseNotFoundError struct { } func (e *ServerSecureInternetBaseNotFoundError) Error() string { - return fmt.Sprintf("secure internet base not found with current: %s", e.Current) + return fmt.Sprintf("secure internet base not found with current location: %s", e.Current) } diff --git a/internal/types/server.go b/internal/types/server.go new file mode 100644 index 0000000..ba9b217 --- /dev/null +++ b/internal/types/server.go @@ -0,0 +1,42 @@ +package types + +// Shared server types + +// Structs that define the json format for +// url: "https://disco.eduvpn.org/v2/organization_list.json" +type DiscoveryOrganizations struct { + Version uint64 `json:"v"` + List []DiscoveryOrganization `json:"organization_list"` + Timestamp int64 `json:"-"` + RawString string `json:"-"` +} + +type DiscoveryOrganization struct { + DisplayName struct { + En string `json:"en"` + } `json:"display_name"` + OrgId string `json:"org_id"` + SecureInternetHome string `json:"secure_internet_home"` + KeywordList struct { + En string `json:"en"` + } `json:"keyword_list"` +} + +// Structs that define the json format for +// url: "https://disco.eduvpn.org/v2/server_list.json" +type DiscoveryServers struct { + Version uint64 `json:"v"` + List []DiscoveryServer `json:"server_list"` + Timestamp int64 `json:"-"` + RawString string `json:"-"` +} + +type DiscoveryServer struct { + AuthenticationURLTemplate string `json:"authentication_url_template"` + BaseURL string `json:"base_url"` + CountryCode string `json:"country_code"` + PublicKeyList []string `json:"public_key_list"` + Type string `json:"server_type"` + SupportContact []string `json:"support_contact"` +} + diff --git a/internal/util/util.go b/internal/util/util.go index 8dee61e..30767c3 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -4,8 +4,11 @@ import ( "crypto/rand" "fmt" "os" + "strings" "time" + "net/url" + "github.com/jwijenbergh/eduvpn-common/internal/types" ) @@ -31,3 +34,23 @@ func EnsureDirectory(directory string) error { } return nil } + +// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md +// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape +func WAYFEncode(input string) string { + // QueryReplace already replaces a space with a + + // see https://go.dev/play/p/pOfrn-Wsq5 + return url.QueryEscape(input) +} + +// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md +func ReplaceWAYF(authTemplate string, authURL string, orgID string) string { + if authTemplate == "" { + return authURL + } + // Replace authURL + authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", WAYFEncode(authURL), 1) + // Replace ORG ID + authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1) + return authTemplate +} @@ -1,6 +1,7 @@ package eduvpn import ( + "fmt" "github.com/jwijenbergh/eduvpn-common/internal/config" "github.com/jwijenbergh/eduvpn-common/internal/discovery" "github.com/jwijenbergh/eduvpn-common/internal/fsm" @@ -77,6 +78,9 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun // Go to the No Server state with the saved servers state.FSM.GoTransitionWithData(fsm.NO_SERVER, state.GetSavedServers(), false) + + state.GetDiscoServers() + state.GetDiscoOrganizations() return nil } @@ -92,51 +96,12 @@ func (state *VPNState) Deregister() error { return nil } -func (state *VPNState) CancelOAuth() error { - errorMessage := "failed to cancel OAuth" - if !state.FSM.InState(fsm.OAUTH_STARTED) { - return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} - } - - currentServer, serverErr := state.Servers.GetCurrentServer() - - if serverErr != nil { - return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} - } - server.CancelOAuth(currentServer) - return nil -} - -func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.Server, error) { - // New server chosen, ensure the server is fresh - server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger) - - if serverErr != nil { - return nil, &types.WrappedErrorMessage{Message: "failed to choose server", Err: serverErr} - } - - // Make sure we are in the chosen state if available - state.FSM.GoTransition(fsm.CHOSEN_SERVER) - return server, nil -} - -func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, forceTCP bool) (string, string, error) { +func (state *VPNState) getConfig(chosenServer server.Server, forceTCP bool) (string, string, error) { errorMessage := "failed to get a configuration for OpenVPN/Wireguard" if state.FSM.InState(fsm.DEREGISTERED) { return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()} } - // Go to no server if possible, else return an error - if !state.FSM.InState(fsm.NO_SERVER) && !state.FSM.GoTransition(fsm.NO_SERVER) { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER}.CustomError()} - } - - // Make sure the server is chosen - chosenServer, serverErr := state.chooseServer(url, isSecureInternet) - - if serverErr != nil { - return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} - } // Relogin with oauth // This moves the state to authorized if server.NeedsRelogin(chosenServer) { @@ -167,14 +132,134 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f return config, configType, nil } +func (state *VPNState) AskSecureLocation() error { + fmt.Println("locations: ", state.Discovery.GetSecureLocationList()) + server, serverErr := state.Discovery.GetServerByCountryCode("NL", "secure_internet") + + if serverErr != nil { + return &types.WrappedErrorMessage{Message: "failed asking secure location", Err: serverErr} + } + + state.Servers.SetSecureLocation(server, &state.FSM, &state.Logger) + + return nil +} + + +func (state *VPNState) addSecureInternetHomeServer(orgID string) (server.Server, error) { + errorMessage := fmt.Sprintf("failed adding Secure Internet home server with organization ID %s", orgID) + // Get the secure internet URL from discovery + secureOrg, secureServer, discoErr := state.Discovery.GetSecureHomeArgs(orgID) + if discoErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} + } + + // Add the secure internet server + server, serverErr := state.Servers.AddSecureInternet(secureOrg, secureServer, &state.FSM, &state.Logger) + + if serverErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + if !state.Servers.HasSecureLocation() { + locationErr := state.AskSecureLocation() + + if locationErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: locationErr} + } + } + + return server, nil +} + +func (state *VPNState) GetConfigSecureInternet(orgID string, forceTCP bool) (string, string, error) { + errorMessage := fmt.Sprintf("failed getting a configuration for Secure Internet organization %s", orgID) + server, serverErr := state.addSecureInternetHomeServer(orgID) + + if serverErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + state.FSM.GoTransition(fsm.CHOSEN_SERVER) + + return state.getConfig(server, forceTCP) +} + +func (state *VPNState) addInstituteServer(url string) (server.Server, error) { + errorMessage := fmt.Sprintf("failed adding Institute Access server with url %s", url) + instituteServer, discoErr := state.Discovery.GetServerByURL(url, "institute_access") + if discoErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr} + } + // Add the secure internet server + server, serverErr := state.Servers.AddInstituteAccess(instituteServer, &state.FSM, &state.Logger) + + if serverErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + state.FSM.GoTransition(fsm.CHOSEN_SERVER) + + return server, nil + +} + +func (state *VPNState) addCustomServer(url string) (server.Server, error) { + errorMessage := fmt.Sprintf("failed adding Custom server with url %s", url) + + instituteServer := &types.DiscoveryServer{BaseURL: url, CountryCode: "NL", Type: "custom", SupportContact: []string{"custom"}} + + // A custom server is just an institute access server under the hood + server, serverErr := state.Servers.AddInstituteAccess(instituteServer, &state.FSM, &state.Logger) + + if serverErr != nil { + return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + state.FSM.GoTransition(fsm.CHOSEN_SERVER) + + return server, nil + +} + func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (string, string, error) { - return state.getConfigWithOptions(url, false, forceTCP) + errorMessage := fmt.Sprintf("failed getting a configuration for Institute Access %s", url) + server, serverErr := state.addInstituteServer(url) + + if serverErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + return state.getConfig(server, forceTCP) } -func (state *VPNState) GetConfigSecureInternet(url string, forceTCP bool) (string, string, error) { - return state.getConfigWithOptions(url, true, forceTCP) +func (state *VPNState) GetConfigCustomServer(url string, forceTCP bool) (string, string, error) { + errorMessage := fmt.Sprintf("failed getting a configuration for custom server %s", url) + server, serverErr := state.addCustomServer(url) + + if serverErr != nil { + return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + + return state.getConfig(server, forceTCP) } +func (state *VPNState) CancelOAuth() error { + errorMessage := "failed to cancel OAuth" + if !state.FSM.InState(fsm.OAUTH_STARTED) { + return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()} + } + + currentServer, serverErr := state.Servers.GetCurrentServer() + + if serverErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr} + } + server.CancelOAuth(currentServer) + return nil +} + + func (state *VPNState) GetDiscoOrganizations() (string, error) { if state.FSM.InState(fsm.DEREGISTERED) { return "", &types.WrappedErrorMessage{Message: "failed to get the organizations with Discovery", Err: fsm.DeregisteredError{}.CustomError()} diff --git a/state_test.go b/state_test.go index 76f4763..c42b314 100644 --- a/state_test.go +++ b/state_test.go @@ -72,7 +72,7 @@ func Test_server(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("Connect error: %v", configErr) @@ -95,7 +95,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters httpw.URLParameters, } }, false) - _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr := state.GetConfigCustomServer(serverURI, false) var wrappedErr *types.WrappedErrorMessage @@ -157,7 +157,7 @@ func Test_token_expired(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("Connect error before expired: %v", configErr) @@ -205,7 +205,7 @@ func Test_token_invalid(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("Connect error before invalid: %v", configErr) @@ -229,7 +229,7 @@ func Test_token_invalid(t *testing.T) { oauth.Token.Access = dummy_value oauth.Token.Refresh = dummy_value - _, _, configErr = state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr = state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("Connect error after invalid: %v", configErr) @@ -255,7 +255,7 @@ func Test_invalid_profile_corrected(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr := state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("First connect error: %v", configErr) @@ -274,7 +274,7 @@ func Test_invalid_profile_corrected(t *testing.T) { previousProfile := base.Profiles.Current base.Profiles.Current = "IDONOTEXIST" - _, _, configErr = state.GetConfigInstituteAccess(serverURI, false) + _, _, configErr = state.GetConfigCustomServer(serverURI, false) if configErr != nil { t.Fatalf("Second connect error: %v", configErr) diff --git a/wrappers/python/main.py b/wrappers/python/main.py index a94281a..5422d93 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -68,7 +68,7 @@ if __name__ == "__main__": print("Failed registering:", e) server = input( - "Which Institute Access server do you want to connect to? (e.g. https://eduvpn.example.com): " + "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): " ) # Ensure we have a valid http prefix @@ -78,7 +78,7 @@ if __name__ == "__main__": # Get a Wireguard/OpenVPN config try: - config, config_type = _eduvpn.get_config_institute_access(server) + config, config_type = _eduvpn.get_config_custom_server(server) except Exception as e: print("Failed to connect:", e) print(f"Got a config with type: {config_type} and contents:\n{config}") diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py index 1ec0bec..f2ae66e 100644 --- a/wrappers/python/src/__init__.py +++ b/wrappers/python/src/__init__.py @@ -47,10 +47,19 @@ VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p, c_char_p) # Exposed functions # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error -lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [ +lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [ c_char_p, c_char_p, c_int, +], MultipleDataError +lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [ + c_char_p, + c_char_p, + c_int, +], MultipleDataError +lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [ + c_char_p, + c_char_p, c_int, ], MultipleDataError lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], c_void_p diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py index 76a08ab..dda3250 100644 --- a/wrappers/python/src/main.py +++ b/wrappers/python/src/main.py @@ -1,4 +1,5 @@ from . import lib, VPNStateChange, encode_args, decode_res +from enum import Enum from typing import Optional, Tuple import threading from .event import StateType, EventHandler @@ -90,14 +91,15 @@ class EduVPN(object): return organizations def get_config( - self, url: str, is_secure_internet: bool = False, force_tcp: bool = False + self, url: str, func: callable, force_tcp: bool = False ): # Because it could be the case that a profile callback is started, store a threading event # In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set # The event is set in self.set_profile self.profile_event = threading.Event() + config, config_type, config_err = self.go_function( - lib.GetConnectConfig, url, is_secure_internet, force_tcp + func, url, force_tcp ) if config_err: @@ -107,15 +109,20 @@ class EduVPN(object): return config, config_type + def get_config_custom_server( + self, url: str, force_tcp: bool = False + ) -> Tuple[str, str]: + return self.get_config(url, lib.GetConfigCustomServer, force_tcp) + def get_config_institute_access( self, url: str, force_tcp: bool = False ) -> Tuple[str, str]: - return self.get_config(url, False, force_tcp) + return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp) def get_config_secure_internet( self, url: str, force_tcp: bool = False ) -> Tuple[str, str]: - return self.get_config(url, True, force_tcp) + return self.get_config(url, lib.GetConfigSecureInternet, force_tcp) def set_connected(self) -> None: connect_err = self.go_function(lib.SetConnected) diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py index 7f17ef6..e58e67d 100644 --- a/wrappers/python/tests.py +++ b/wrappers/python/tests.py @@ -29,7 +29,7 @@ class ConfigTests(unittest.TestCase): self.fail("No SERVER_URI environment variable given") # This can throw an exception - _eduvpn.get_config_institute_access(server_uri) + _eduvpn.get_config_custom_server(server_uri) # Deregister _eduvpn.deregister() |
