summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-07-05 13:17:24 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-07-05 13:17:24 +0200
commit1865b016d0cca74cd3703db5a3b4217917988dec (patch)
tree3da84dbc4f1ad49221c25fb83f402d27deb34138
parente39b9a8a405fa8e5f73c32bb03a3f349f7f9f92d (diff)
Refactor: Handling of different servers and identifiers
- Uses OrgID for Secure Internet and gets the data from discovery - Uses URL for Institute/Custom and gets the data from discovery - Implements SKIP WAYF as we now have the needed data - Implements an initial change location with a default location (NL right now)
-rw-r--r--cmd/cli/main.go35
-rw-r--r--exports/exports.go35
-rw-r--r--internal/discovery/discovery.go145
-rw-r--r--internal/oauth/oauth.go13
-rw-r--r--internal/server/server.go135
-rw-r--r--internal/types/server.go42
-rw-r--r--internal/util/util.go23
-rw-r--r--state.go171
-rw-r--r--state_test.go14
-rw-r--r--wrappers/python/main.py4
-rw-r--r--wrappers/python/src/__init__.py11
-rw-r--r--wrappers/python/src/main.py15
-rw-r--r--wrappers/python/tests.py2
13 files changed, 482 insertions, 163 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go
index e6be0bf..93b81e0 100644
--- a/cmd/cli/main.go
+++ b/cmd/cli/main.go
@@ -13,6 +13,13 @@ import (
eduvpn "github.com/jwijenbergh/eduvpn-common"
)
+type ServerTypes int8
+const (
+ ServerTypeInstituteAccess ServerTypes = iota
+ ServerTypeSecureInternet
+ ServerTypeCustom
+)
+
// Open a browser with xdg-open
func openBrowser(urlString string) {
fmt.Printf("OAuth: Initialized with AuthURL %s\n", urlString)
@@ -89,16 +96,15 @@ func stateCallback(state *eduvpn.VPNState, oldState string, newState string, dat
}
// Get a config for Institute Access or Secure Internet Server
-func getConfig(state *eduvpn.VPNState, url string, isInstitute bool) (string, string, error) {
- if !strings.HasPrefix(url, "https://") {
+func getConfig(state *eduvpn.VPNState, url string, serverType ServerTypes) (string, string, error) {
+ if !strings.HasPrefix(url, "http") {
url = "https://" + url
}
- if !strings.HasSuffix(url, "/") {
- url += "/"
- }
// Force TCP is set to False
- if isInstitute {
+ if serverType == ServerTypeInstituteAccess {
return state.GetConfigInstituteAccess(url, false)
+ } else if serverType == ServerTypeCustom {
+ return state.GetConfigCustomServer(url, false)
}
return state.GetConfigSecureInternet(url, false)
}
@@ -136,7 +142,7 @@ func storeSecureInternetConfig(state *eduvpn.VPNState, url string, directory str
fmt.Println("Creating and storing cert for", url)
- config, _, configErr := getConfig(state, url, false)
+ config, _, configErr := getConfig(state, url, ServerTypeSecureInternet)
if configErr != nil {
fmt.Printf("Failed obtaining config for url %s with error %v\n", url, configErr)
@@ -193,7 +199,7 @@ func getSecureInternetAll(homeURL string) {
}
// Get a config for a single server, Institute Access or Secure Internet
-func printConfig(url string, isInstitute bool) {
+func printConfig(url string, serverType ServerTypes) {
state := &eduvpn.VPNState{}
state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) {
@@ -202,7 +208,7 @@ func printConfig(url string, isInstitute bool) {
defer state.Deregister()
- config, _, configErr := getConfig(state, url, isInstitute)
+ config, _, configErr := getConfig(state, url, serverType)
if configErr != nil {
// Show the usage of tracebacks and causes
@@ -217,20 +223,25 @@ func printConfig(url string, isInstitute bool) {
// The main function
// It parses the arguments and executes the correct functions
func main() {
+ customUrlArg := flag.String("get-custom", "", "The url of a custom server to connect to")
urlArg := flag.String("get-institute", "", "The url of an institute to connect to")
secureInternet := flag.String("get-secure", "", "Gets secure internet servers.")
secureInternetAll := flag.String("get-secure-all", "", "Gets certificates for all secure internet servers. It stores them in ./certs. Provide an URL for the home server e.g. nl.eduvpn.org.")
flag.Parse()
// Connect to a VPN by getting an Institute Access config
+ customUrlString := *customUrlArg
urlString := *urlArg
secureInternetString := *secureInternet
secureInternetAllString := *secureInternetAll
- if urlString != "" {
- printConfig(urlString, true)
+ if customUrlString != "" {
+ printConfig(customUrlString, ServerTypeCustom)
+ return
+ } else if urlString != "" {
+ printConfig(urlString, ServerTypeInstituteAccess)
return
} else if secureInternetString != "" {
- printConfig(secureInternetString, false)
+ printConfig(secureInternetString, ServerTypeSecureInternet)
return
} else if secureInternetAllString != "" {
getSecureInternetAll(secureInternetAllString)
diff --git a/exports/exports.go b/exports/exports.go
index 567e189..6b33da6 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -107,22 +107,39 @@ func CancelOAuth(name *C.char) *C.char {
return C.CString(cancelErrString)
}
-//export GetConnectConfig
-func GetConnectConfig(name *C.char, url *C.char, isSecureInternet C.int, forceTCP C.int) (*C.char, *C.char, *C.char) {
+//export GetConfigSecureInternet
+func GetConfigSecureInternet(name *C.char, orgID *C.char, forceTCP C.int) (*C.char, *C.char, *C.char) {
nameStr := C.GoString(name)
state, stateErr := GetVPNState(nameStr)
if stateErr != nil {
return nil, nil, C.CString(ErrorToString(stateErr))
}
- var config string
- var configType string
- var configErr error
forceTCPBool := forceTCP == 1
- if isSecureInternet == 1 {
- config, configType, configErr = state.GetConfigSecureInternet(C.GoString(url), forceTCPBool)
- } else {
- config, configType, configErr = state.GetConfigInstituteAccess(C.GoString(url), forceTCPBool)
+ config, configType, configErr := state.GetConfigSecureInternet(C.GoString(orgID), forceTCPBool)
+ return C.CString(config), C.CString(configType), C.CString(ErrorToString(configErr))
+}
+
+//export GetConfigInstituteAccess
+func GetConfigInstituteAccess(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char, *C.char) {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return nil, nil, C.CString(ErrorToString(stateErr))
}
+ forceTCPBool := forceTCP == 1
+ config, configType, configErr := state.GetConfigInstituteAccess(C.GoString(url), forceTCPBool)
+ return C.CString(config), C.CString(configType), C.CString(ErrorToString(configErr))
+}
+
+//export GetConfigCustomServer
+func GetConfigCustomServer(name *C.char, url *C.char, forceTCP C.int) (*C.char, *C.char, *C.char) {
+ nameStr := C.GoString(name)
+ state, stateErr := GetVPNState(nameStr)
+ if stateErr != nil {
+ return nil, nil, C.CString(ErrorToString(stateErr))
+ }
+ forceTCPBool := forceTCP == 1
+ config, configType, configErr := state.GetConfigCustomServer(C.GoString(url), forceTCPBool)
return C.CString(config), C.CString(configType), C.CString(ErrorToString(configErr))
}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index ac3bf57..c61469d 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -12,27 +12,15 @@ import (
"github.com/jwijenbergh/eduvpn-common/internal/verify"
)
-type OrganizationList struct {
- JSON json.RawMessage `json:"organization_list"`
- Version uint64 `json:"v"`
- Timestamp int64 `json:"-"`
-}
-
-type ServersList struct {
- JSON json.RawMessage `json:"server_list"`
- Version uint64 `json:"v"`
- Timestamp int64 `json:"-"`
-}
-
type Discovery struct {
- Organizations OrganizationList
- Servers ServersList
+ Organizations types.DiscoveryOrganizations
+ Servers types.DiscoveryServers
FSM *fsm.FSM
Logger *log.FileLogger
}
// Helper function that gets a disco json
-func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) error {
+func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}) (string, error) {
errorMessage := fmt.Sprintf("failed getting file: %s from the Discovery server", jsonFile)
// Get json data
discoURL := "https://disco.eduvpn.org/v2/"
@@ -40,7 +28,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, fileBody, fileErr := http.HTTPGet(fileURL)
if fileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: fileErr}
}
// Get signature
@@ -49,7 +37,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
_, sigBody, sigFileErr := http.HTTPGet(sigURL)
if sigFileErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: sigFileErr}
}
// Verify signature
@@ -58,17 +46,17 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{}
verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash)
if !verifySuccess || verifyErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: verifyErr}
}
// Parse JSON to extract version and list
jsonErr := json.Unmarshal(fileBody, structure)
if jsonErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: jsonErr}
}
- return nil
+ return string(fileBody), nil
}
func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
@@ -82,7 +70,63 @@ func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) {
// - [TODO] when the user tries to add new server AND the user did NOT yet choose an organization before;
// - [TODO] when the authorization for the server associated with an already chosen organization is triggered, e.g. after expiry or revocation.
func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
- return string(discovery.Organizations.JSON) == ""
+ return discovery.Organizations.Timestamp == 0
+}
+
+func (discovery *Discovery) GetSecureLocationList() []string {
+ var locations []string
+ for _, server := range discovery.Servers.List {
+ if server.Type == "secure_internet" {
+ locations = append(locations, server.CountryCode)
+ }
+ }
+ return locations
+}
+
+func (discovery *Discovery) GetServerByURL(url string, _type string) (*types.DiscoveryServer, error) {
+ for _, server := range discovery.Servers.List {
+ if server.BaseURL == url && server.Type == _type {
+ return &server, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting server by URL from discovery", Err: &GetServerByURLNotFoundError{URL: url, Type: _type}}
+}
+
+func (discovery *Discovery) GetServerByCountryCode(code string, _type string) (*types.DiscoveryServer, error) {
+ for _, server := range discovery.Servers.List {
+ if server.CountryCode == code && server.Type == _type {
+ return &server, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting server by country code from discovery", Err: &GetServerByCountryCodeNotFoundError{CountryCode: code, Type: _type}}
+}
+
+func (discovery *Discovery) getOrgByID(orgID string) (*types.DiscoveryOrganization, error) {
+ for _, organization := range discovery.Organizations.List {
+ if organization.OrgId == orgID {
+ return &organization, nil
+ }
+ }
+ return nil, &types.WrappedErrorMessage{Message: "failed getting Secure Internet Home URL from discovery", Err: &GetOrgByIDNotFoundError{ID: orgID}}
+}
+
+func (discovery *Discovery) GetSecureHomeArgs(orgID string) (*types.DiscoveryOrganization, *types.DiscoveryServer, error) {
+ errorMessage := "failed getting Secure Internet Home arguments from discovery"
+ org, orgErr := discovery.getOrgByID(orgID)
+
+ if orgErr != nil {
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: orgErr}
+ }
+
+ // Get a server with the base url
+ url := org.SecureInternetHome
+
+ server, serverErr := discovery.GetServerByURL(url, "secure_internet")
+
+ if serverErr != nil {
+ return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+ return org, server, nil
}
// https://github.com/eduvpn/documentation/blob/v3/SERVER_DISCOVERY.md
@@ -90,7 +134,7 @@ func (discovery *Discovery) DetermineOrganizationsUpdate() bool {
// - The application MAY refresh the server_list.json periodically, e.g. once every hour.
func (discovery *Discovery) DetermineServersUpdate() bool {
// No servers, we should update
- if string(discovery.Servers.JSON) == "" {
+ if discovery.Servers.Timestamp == 0 {
return true
}
// 1 hour from the last update
@@ -106,29 +150,66 @@ func (discovery *Discovery) DetermineServersUpdate() bool {
// Get the organization list
func (discovery *Discovery) GetOrganizationsList() (string, error) {
if !discovery.DetermineOrganizationsUpdate() {
- return string(discovery.Organizations.JSON), nil
+ return discovery.Organizations.RawString, nil
}
file := "organization_list.json"
- err := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
- if err != nil {
+ body, bodyErr := getDiscoFile(file, discovery.Organizations.Version, &discovery.Organizations)
+ if bodyErr != nil {
// Return previous with an error
- return string(discovery.Organizations.JSON), &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: err}
+ return discovery.Organizations.RawString, &types.WrappedErrorMessage{Message: "failed getting organizations in Discovery", Err: bodyErr}
}
- return string(discovery.Organizations.JSON), nil
+ discovery.Organizations.RawString = body
+ discovery.Organizations.Timestamp = util.GenerateTimeSeconds()
+ return discovery.Organizations.RawString, nil
}
// Get the server list
func (discovery *Discovery) GetServersList() (string, error) {
if !discovery.DetermineServersUpdate() {
- return string(discovery.Servers.JSON), nil
+ return discovery.Servers.RawString, nil
}
file := "server_list.json"
- err := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
- if err != nil {
+ body, bodyErr := getDiscoFile(file, discovery.Servers.Version, &discovery.Servers)
+ if bodyErr != nil {
// Return previous with an error
- return string(discovery.Servers.JSON), &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: err}
+ return discovery.Servers.RawString, &types.WrappedErrorMessage{Message: "failed getting servers in Discovery", Err: bodyErr}
}
// Update servers timestamp
+ discovery.Servers.RawString = body
discovery.Servers.Timestamp = util.GenerateTimeSeconds()
- return string(discovery.Servers.JSON), nil
+ return discovery.Servers.RawString, nil
+}
+
+type GetOrgByIDNotFoundError struct {
+ ID string
+}
+
+func (e GetOrgByIDNotFoundError) Error() string {
+ return fmt.Sprintf("No Secure Internet Home found in organizations with ID %s", e.ID)
+}
+
+type GetServerByURLNotFoundError struct {
+ URL string
+ Type string
+}
+
+func (e GetServerByURLNotFoundError) Error() string {
+ return fmt.Sprintf("No institute access server found in organizations with URL %s and type %s", e.URL, e.Type)
+}
+
+type GetServerByCountryCodeNotFoundError struct {
+ CountryCode string
+ Type string
+}
+
+func (e GetServerByCountryCodeNotFoundError) Error() string {
+ return fmt.Sprintf("No institute access server found in organizations with country code %s and type %s", e.CountryCode, e.Type)
+}
+
+type GetSecureHomeArgsNotFoundError struct {
+ URL string
+}
+
+func (e GetSecureHomeArgsNotFoundError) Error() string {
+ return fmt.Sprintf("No Secure Internet Home found with URL: %s", e.URL)
}
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index 824db90..ef1bed4 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -223,11 +223,6 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) {
}
}
-func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) {
- oauth.FSM = fsm
- oauth.Logger = logger
-}
-
func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM, logger *log.FileLogger) {
oauth.BaseAuthorizationURL = baseAuthorizationURL
oauth.TokenURL = tokenURL
@@ -236,7 +231,7 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.
}
// Starts the OAuth exchange for eduvpn.
-func (oauth *OAuth) start(name string) error {
+func (oauth *OAuth) start(name string, postprocessAuth func(string) string) error {
errorMessage := "failed starting OAuth exchange"
if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) {
return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
@@ -274,7 +269,7 @@ func (oauth *OAuth) start(name string) error {
oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
oauth.Session = oauthSession
// Run the state callback in the background so that the user can login while we start the callback server
- oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true)
+ oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, postprocessAuth(authURL), true)
return nil
}
@@ -298,9 +293,9 @@ func (oauth *OAuth) Cancel() {
oauth.Session.Server.Shutdown(oauth.Session.Context)
}
-func (oauth *OAuth) Login(name string) error {
+func (oauth *OAuth) Login(name string, postprocessAuth func(string) string) error {
errorMessage := "failed OAuth login"
- authInitializeErr := oauth.start(name)
+ authInitializeErr := oauth.start(name, postprocessAuth)
if authInitializeErr != nil {
return &types.WrappedErrorMessage{Message: errorMessage, Err: authInitializeErr}
diff --git a/internal/server/server.go b/internal/server/server.go
index 807bd09..5bc2ea1 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -36,14 +36,16 @@ type InstituteAccessServer struct {
// A secure internet server which has its own OAuth tokens
// It specifies the current location url it is connected to
type SecureInternetHomeServer struct {
+ DisplayName string `json:"display_name"`
OAuth oauth.OAuth `json:"oauth"`
- // The home server has a list of info for each configured server
+ // The home server has a list of info for each configured server location
BaseMap map[string]*ServerBase `json:"base_map"`
- // We have the home url and the current url
- HomeURL string `json:"home_url"`
- CurrentURL string `json:"current_url"`
+ // We have the authorization URL template, the home organization ID and the current location
+ AuthorizationTemplate string `json:"authorization_template"`
+ HomeOrganizationID string `json:"home_organization_id"`
+ CurrentLocation string `json:"current_location"`
}
type InstituteServers struct {
@@ -89,11 +91,11 @@ type Server interface {
// Gets the current OAuth object
GetOAuth() *oauth.OAuth
+ // Get the authorization URL template function
+ GetTemplateAuth() func(string) string
+
// Gets the server base
GetBase() (*ServerBase, error)
-
- // initialize method
- init(url string, fsm *fsm.FSM, logger *log.FileLogger) error
}
// For an institute, we can simply get the OAuth
@@ -105,6 +107,19 @@ func (secure *SecureInternetHomeServer) GetOAuth() *oauth.OAuth {
return &secure.OAuth
}
+
+func (institute *InstituteAccessServer) GetTemplateAuth() (func(string) string) {
+ return func(authURL string) string {
+ return authURL
+ }
+}
+
+func (secure *SecureInternetHomeServer) GetTemplateAuth() (func(string) string) {
+ return func(authURL string) string {
+ return util.ReplaceWAYF(secure.AuthorizationTemplate, authURL, secure.HomeOrganizationID)
+ }
+}
+
func (institute *InstituteAccessServer) GetBase() (*ServerBase, error) {
return &institute.Base, nil
}
@@ -115,10 +130,10 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) {
return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetMapNotFoundError{}}
}
- base, exists := server.BaseMap[server.CurrentURL]
+ base, exists := server.BaseMap[server.CurrentLocation]
if !exists {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentURL}}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: &ServerSecureInternetBaseNotFoundError{Current: server.CurrentLocation}}
}
return base, nil
}
@@ -137,23 +152,27 @@ func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *l
return nil
}
-func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error {
- errorMessage := fmt.Sprintf("failed initializing secure internet home server %s", url)
+func (servers *Servers) HasSecureLocation() bool {
+ return servers.SecureInternetHomeServer.CurrentLocation != ""
+}
+
+func (secure *SecureInternetHomeServer) addLocation(locationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (*ServerBase, error) {
+ errorMessage := "failed adding a location"
// Initialize the base map if it is non-nil
if secure.BaseMap == nil {
secure.BaseMap = make(map[string]*ServerBase)
}
- // Add it if not present
- base, exists := secure.BaseMap[url]
+ // Add the location to the base map
+ base, exists := secure.BaseMap[locationServer.CountryCode]
if !exists || base == nil {
// Create the base to be added to the map
base = &ServerBase{}
- base.URL = url
- endpoints, endpointsErr := APIGetEndpoints(url)
+ base.URL = locationServer.BaseURL
+ endpoints, endpointsErr := APIGetEndpoints(locationServer.BaseURL)
if endpointsErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: endpointsErr}
}
base.Endpoints = *endpoints
}
@@ -163,19 +182,34 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *l
base.Logger = logger
// Ensure it is in the map
- secure.BaseMap[url] = base
+ secure.BaseMap[locationServer.CountryCode] = base
+ return base, nil
+}
+
- // Set the home url if it is not set yet
- if secure.HomeURL == "" {
- secure.HomeURL = url
- // Make sure oauth contains our endpoints
- secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger)
- } else { // Else just pass in the fsm and logger
- secure.OAuth.Update(fsm, logger)
+// Initializes the home server and adds its own location
+func (secure *SecureInternetHomeServer) init(homeOrg *types.DiscoveryOrganization, homeLocation *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error {
+ errorMessage := "failed initializing secure internet home server"
+
+ if secure.HomeOrganizationID != homeOrg.OrgId {
+ // New home organisation, clear everything
+ *secure = *&SecureInternetHomeServer{}
+ }
+
+
+ base, baseErr := secure.addLocation(homeLocation, fsm, logger)
+
+ if baseErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: baseErr}
}
- // Set the current url
- secure.CurrentURL = url
+ // Make sure to set the organization ID
+ secure.HomeOrganizationID = homeOrg.OrgId
+
+ // Make sure to set the authorization URL template
+ secure.AuthorizationTemplate = homeLocation.AuthenticationURLTemplate
+ // Make sure oauth contains our endpoints
+ secure.OAuth.Init(base.Endpoints.API.V3.Authorization, base.Endpoints.API.V3.Token, fsm, logger)
return nil
}
@@ -211,7 +245,7 @@ func ShouldRenewButton(server Server) (bool, error) {
}
func Login(server Server) error {
- return server.GetOAuth().Login("org.eduvpn.app.linux")
+ return server.GetOAuth().Login("org.eduvpn.app.linux", server.GetTemplateAuth())
}
func EnsureTokens(server Server) error {
@@ -240,21 +274,9 @@ func CancelOAuth(server Server) {
server.GetOAuth().Cancel()
}
-func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
- errorMessage := "failed ensuring server"
- // Intialize the secure internet server
- // This calls the init method which takes care of the rest
- if isSecureInternet {
- initErr := servers.SecureInternetHomeServer.init(url, fsm, logger)
-
- if initErr != nil {
- return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
- }
-
- servers.IsSecureInternet = true
- return &servers.SecureInternetHomeServer, nil
- }
-
+func (servers *Servers) AddInstituteAccess(instituteServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
+ url := instituteServer.BaseURL
+ errorMessage := fmt.Sprintf("failed adding institute access server: %s", url)
instituteServers := &servers.InstituteServers
if instituteServers.Map == nil {
@@ -279,6 +301,33 @@ func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm
return institute, nil
}
+func (servers *Servers) SetSecureLocation(chosenLocationServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) error {
+ errorMessage := "failed to set secure location"
+ // Make sure to add the current location
+ _, addLocationErr := servers.SecureInternetHomeServer.addLocation(chosenLocationServer, fsm, logger)
+
+ if addLocationErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: addLocationErr}
+ }
+
+ servers.SecureInternetHomeServer.CurrentLocation = chosenLocationServer.CountryCode
+ return nil
+}
+
+func (servers *Servers) AddSecureInternet(secureOrg *types.DiscoveryOrganization, secureServer *types.DiscoveryServer, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) {
+ errorMessage := "failed adding secure internet server"
+ // If we have specified an organization ID
+ // We also need to get an authorization template
+ initErr := servers.SecureInternetHomeServer.init(secureOrg, secureServer, fsm, logger)
+
+ if initErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: initErr}
+ }
+
+ servers.IsSecureInternet = true
+ return &servers.SecureInternetHomeServer, nil
+}
+
type ServerProfile struct {
ID string `json:"profile_id"`
DisplayName string `json:"display_name"`
@@ -554,5 +603,5 @@ type ServerSecureInternetBaseNotFoundError struct {
}
func (e *ServerSecureInternetBaseNotFoundError) Error() string {
- return fmt.Sprintf("secure internet base not found with current: %s", e.Current)
+ return fmt.Sprintf("secure internet base not found with current location: %s", e.Current)
}
diff --git a/internal/types/server.go b/internal/types/server.go
new file mode 100644
index 0000000..ba9b217
--- /dev/null
+++ b/internal/types/server.go
@@ -0,0 +1,42 @@
+package types
+
+// Shared server types
+
+// Structs that define the json format for
+// url: "https://disco.eduvpn.org/v2/organization_list.json"
+type DiscoveryOrganizations struct {
+ Version uint64 `json:"v"`
+ List []DiscoveryOrganization `json:"organization_list"`
+ Timestamp int64 `json:"-"`
+ RawString string `json:"-"`
+}
+
+type DiscoveryOrganization struct {
+ DisplayName struct {
+ En string `json:"en"`
+ } `json:"display_name"`
+ OrgId string `json:"org_id"`
+ SecureInternetHome string `json:"secure_internet_home"`
+ KeywordList struct {
+ En string `json:"en"`
+ } `json:"keyword_list"`
+}
+
+// Structs that define the json format for
+// url: "https://disco.eduvpn.org/v2/server_list.json"
+type DiscoveryServers struct {
+ Version uint64 `json:"v"`
+ List []DiscoveryServer `json:"server_list"`
+ Timestamp int64 `json:"-"`
+ RawString string `json:"-"`
+}
+
+type DiscoveryServer struct {
+ AuthenticationURLTemplate string `json:"authentication_url_template"`
+ BaseURL string `json:"base_url"`
+ CountryCode string `json:"country_code"`
+ PublicKeyList []string `json:"public_key_list"`
+ Type string `json:"server_type"`
+ SupportContact []string `json:"support_contact"`
+}
+
diff --git a/internal/util/util.go b/internal/util/util.go
index 8dee61e..30767c3 100644
--- a/internal/util/util.go
+++ b/internal/util/util.go
@@ -4,8 +4,11 @@ import (
"crypto/rand"
"fmt"
"os"
+ "strings"
"time"
+ "net/url"
+
"github.com/jwijenbergh/eduvpn-common/internal/types"
)
@@ -31,3 +34,23 @@ func EnsureDirectory(directory string) error {
}
return nil
}
+
+// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
+// URL encode for skipping where are you from (WAYF). Note that this right now is basically an alias to QueryEscape
+func WAYFEncode(input string) string {
+ // QueryReplace already replaces a space with a +
+ // see https://go.dev/play/p/pOfrn-Wsq5
+ return url.QueryEscape(input)
+}
+
+// See https://github.com/eduvpn/documentation/blob/dc4d53c47dd7a69e95d6650eec408e16eaa814a2/SERVER_DISCOVERY_SKIP_WAYF.md
+func ReplaceWAYF(authTemplate string, authURL string, orgID string) string {
+ if authTemplate == "" {
+ return authURL
+ }
+ // Replace authURL
+ authTemplate = strings.Replace(authTemplate, "@RETURN_TO@", WAYFEncode(authURL), 1)
+ // Replace ORG ID
+ authTemplate = strings.Replace(authTemplate, "@ORG_ID@", WAYFEncode(orgID), 1)
+ return authTemplate
+}
diff --git a/state.go b/state.go
index 44d277e..8306ca0 100644
--- a/state.go
+++ b/state.go
@@ -1,6 +1,7 @@
package eduvpn
import (
+ "fmt"
"github.com/jwijenbergh/eduvpn-common/internal/config"
"github.com/jwijenbergh/eduvpn-common/internal/discovery"
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
@@ -77,6 +78,9 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun
// Go to the No Server state with the saved servers
state.FSM.GoTransitionWithData(fsm.NO_SERVER, state.GetSavedServers(), false)
+
+ state.GetDiscoServers()
+ state.GetDiscoOrganizations()
return nil
}
@@ -92,51 +96,12 @@ func (state *VPNState) Deregister() error {
return nil
}
-func (state *VPNState) CancelOAuth() error {
- errorMessage := "failed to cancel OAuth"
- if !state.FSM.InState(fsm.OAUTH_STARTED) {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
- }
-
- currentServer, serverErr := state.Servers.GetCurrentServer()
-
- if serverErr != nil {
- return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
- }
- server.CancelOAuth(currentServer)
- return nil
-}
-
-func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.Server, error) {
- // New server chosen, ensure the server is fresh
- server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger)
-
- if serverErr != nil {
- return nil, &types.WrappedErrorMessage{Message: "failed to choose server", Err: serverErr}
- }
-
- // Make sure we are in the chosen state if available
- state.FSM.GoTransition(fsm.CHOSEN_SERVER)
- return server, nil
-}
-
-func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, forceTCP bool) (string, string, error) {
+func (state *VPNState) getConfig(chosenServer server.Server, forceTCP bool) (string, string, error) {
errorMessage := "failed to get a configuration for OpenVPN/Wireguard"
if state.FSM.InState(fsm.DEREGISTERED) {
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()}
}
- // Go to no server if possible, else return an error
- if !state.FSM.InState(fsm.NO_SERVER) && !state.FSM.GoTransition(fsm.NO_SERVER) {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER}.CustomError()}
- }
-
- // Make sure the server is chosen
- chosenServer, serverErr := state.chooseServer(url, isSecureInternet)
-
- if serverErr != nil {
- return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
- }
// Relogin with oauth
// This moves the state to authorized
if server.NeedsRelogin(chosenServer) {
@@ -167,14 +132,134 @@ func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, f
return config, configType, nil
}
+func (state *VPNState) AskSecureLocation() error {
+ fmt.Println("locations: ", state.Discovery.GetSecureLocationList())
+ server, serverErr := state.Discovery.GetServerByCountryCode("NL", "secure_internet")
+
+ if serverErr != nil {
+ return &types.WrappedErrorMessage{Message: "failed asking secure location", Err: serverErr}
+ }
+
+ state.Servers.SetSecureLocation(server, &state.FSM, &state.Logger)
+
+ return nil
+}
+
+
+func (state *VPNState) addSecureInternetHomeServer(orgID string) (server.Server, error) {
+ errorMessage := fmt.Sprintf("failed adding Secure Internet home server with organization ID %s", orgID)
+ // Get the secure internet URL from discovery
+ secureOrg, secureServer, discoErr := state.Discovery.GetSecureHomeArgs(orgID)
+ if discoErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr}
+ }
+
+ // Add the secure internet server
+ server, serverErr := state.Servers.AddSecureInternet(secureOrg, secureServer, &state.FSM, &state.Logger)
+
+ if serverErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+
+ if !state.Servers.HasSecureLocation() {
+ locationErr := state.AskSecureLocation()
+
+ if locationErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: locationErr}
+ }
+ }
+
+ return server, nil
+}
+
+func (state *VPNState) GetConfigSecureInternet(orgID string, forceTCP bool) (string, string, error) {
+ errorMessage := fmt.Sprintf("failed getting a configuration for Secure Internet organization %s", orgID)
+ server, serverErr := state.addSecureInternetHomeServer(orgID)
+
+ if serverErr != nil {
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+
+ state.FSM.GoTransition(fsm.CHOSEN_SERVER)
+
+ return state.getConfig(server, forceTCP)
+}
+
+func (state *VPNState) addInstituteServer(url string) (server.Server, error) {
+ errorMessage := fmt.Sprintf("failed adding Institute Access server with url %s", url)
+ instituteServer, discoErr := state.Discovery.GetServerByURL(url, "institute_access")
+ if discoErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: discoErr}
+ }
+ // Add the secure internet server
+ server, serverErr := state.Servers.AddInstituteAccess(instituteServer, &state.FSM, &state.Logger)
+
+ if serverErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+
+ state.FSM.GoTransition(fsm.CHOSEN_SERVER)
+
+ return server, nil
+
+}
+
+func (state *VPNState) addCustomServer(url string) (server.Server, error) {
+ errorMessage := fmt.Sprintf("failed adding Custom server with url %s", url)
+
+ instituteServer := &types.DiscoveryServer{BaseURL: url, CountryCode: "NL", Type: "custom", SupportContact: []string{"custom"}}
+
+ // A custom server is just an institute access server under the hood
+ server, serverErr := state.Servers.AddInstituteAccess(instituteServer, &state.FSM, &state.Logger)
+
+ if serverErr != nil {
+ return nil, &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+
+ state.FSM.GoTransition(fsm.CHOSEN_SERVER)
+
+ return server, nil
+
+}
+
func (state *VPNState) GetConfigInstituteAccess(url string, forceTCP bool) (string, string, error) {
- return state.getConfigWithOptions(url, false, forceTCP)
+ errorMessage := fmt.Sprintf("failed getting a configuration for Institute Access %s", url)
+ server, serverErr := state.addInstituteServer(url)
+
+ if serverErr != nil {
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+
+ return state.getConfig(server, forceTCP)
}
-func (state *VPNState) GetConfigSecureInternet(url string, forceTCP bool) (string, string, error) {
- return state.getConfigWithOptions(url, true, forceTCP)
+func (state *VPNState) GetConfigCustomServer(url string, forceTCP bool) (string, string, error) {
+ errorMessage := fmt.Sprintf("failed getting a configuration for custom server %s", url)
+ server, serverErr := state.addCustomServer(url)
+
+ if serverErr != nil {
+ return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+
+ return state.getConfig(server, forceTCP)
}
+func (state *VPNState) CancelOAuth() error {
+ errorMessage := "failed to cancel OAuth"
+ if !state.FSM.InState(fsm.OAUTH_STARTED) {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
+ }
+
+ currentServer, serverErr := state.Servers.GetCurrentServer()
+
+ if serverErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: serverErr}
+ }
+ server.CancelOAuth(currentServer)
+ return nil
+}
+
+
func (state *VPNState) GetDiscoOrganizations() (string, error) {
if state.FSM.InState(fsm.DEREGISTERED) {
return "", &types.WrappedErrorMessage{Message: "failed to get the organizations with Discovery", Err: fsm.DeregisteredError{}.CustomError()}
diff --git a/state_test.go b/state_test.go
index 76f4763..c42b314 100644
--- a/state_test.go
+++ b/state_test.go
@@ -72,7 +72,7 @@ func Test_server(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, _, configErr := state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr := state.GetConfigCustomServer(serverURI, false)
if configErr != nil {
t.Fatalf("Connect error: %v", configErr)
@@ -95,7 +95,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters httpw.URLParameters,
}
}, false)
- _, _, configErr := state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr := state.GetConfigCustomServer(serverURI, false)
var wrappedErr *types.WrappedErrorMessage
@@ -157,7 +157,7 @@ func Test_token_expired(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, _, configErr := state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr := state.GetConfigCustomServer(serverURI, false)
if configErr != nil {
t.Fatalf("Connect error before expired: %v", configErr)
@@ -205,7 +205,7 @@ func Test_token_invalid(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, _, configErr := state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr := state.GetConfigCustomServer(serverURI, false)
if configErr != nil {
t.Fatalf("Connect error before invalid: %v", configErr)
@@ -229,7 +229,7 @@ func Test_token_invalid(t *testing.T) {
oauth.Token.Access = dummy_value
oauth.Token.Refresh = dummy_value
- _, _, configErr = state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr = state.GetConfigCustomServer(serverURI, false)
if configErr != nil {
t.Fatalf("Connect error after invalid: %v", configErr)
@@ -255,7 +255,7 @@ func Test_invalid_profile_corrected(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, _, configErr := state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr := state.GetConfigCustomServer(serverURI, false)
if configErr != nil {
t.Fatalf("First connect error: %v", configErr)
@@ -274,7 +274,7 @@ func Test_invalid_profile_corrected(t *testing.T) {
previousProfile := base.Profiles.Current
base.Profiles.Current = "IDONOTEXIST"
- _, _, configErr = state.GetConfigInstituteAccess(serverURI, false)
+ _, _, configErr = state.GetConfigCustomServer(serverURI, false)
if configErr != nil {
t.Fatalf("Second connect error: %v", configErr)
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index a94281a..5422d93 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -68,7 +68,7 @@ if __name__ == "__main__":
print("Failed registering:", e)
server = input(
- "Which Institute Access server do you want to connect to? (e.g. https://eduvpn.example.com): "
+ "Which server (Custom/Institute Access) do you want to connect to? (e.g. https://eduvpn.example.com): "
)
# Ensure we have a valid http prefix
@@ -78,7 +78,7 @@ if __name__ == "__main__":
# Get a Wireguard/OpenVPN config
try:
- config, config_type = _eduvpn.get_config_institute_access(server)
+ config, config_type = _eduvpn.get_config_custom_server(server)
except Exception as e:
print("Failed to connect:", e)
print(f"Got a config with type: {config_type} and contents:\n{config}")
diff --git a/wrappers/python/src/__init__.py b/wrappers/python/src/__init__.py
index 1ec0bec..f2ae66e 100644
--- a/wrappers/python/src/__init__.py
+++ b/wrappers/python/src/__init__.py
@@ -47,10 +47,19 @@ VPNStateChange = CFUNCTYPE(None, c_char_p, c_char_p, c_char_p, c_char_p)
# Exposed functions
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
-lib.GetConnectConfig.argtypes, lib.GetConnectConfig.restype = [
+lib.GetConfigSecureInternet.argtypes, lib.GetConfigSecureInternet.restype = [
c_char_p,
c_char_p,
c_int,
+], MultipleDataError
+lib.GetConfigInstituteAccess.argtypes, lib.GetConfigInstituteAccess.restype = [
+ c_char_p,
+ c_char_p,
+ c_int,
+], MultipleDataError
+lib.GetConfigCustomServer.argtypes, lib.GetConfigCustomServer.restype = [
+ c_char_p,
+ c_char_p,
c_int,
], MultipleDataError
lib.Deregister.argtypes, lib.Deregister.restype = [c_char_p], c_void_p
diff --git a/wrappers/python/src/main.py b/wrappers/python/src/main.py
index 76a08ab..dda3250 100644
--- a/wrappers/python/src/main.py
+++ b/wrappers/python/src/main.py
@@ -1,4 +1,5 @@
from . import lib, VPNStateChange, encode_args, decode_res
+from enum import Enum
from typing import Optional, Tuple
import threading
from .event import StateType, EventHandler
@@ -90,14 +91,15 @@ class EduVPN(object):
return organizations
def get_config(
- self, url: str, is_secure_internet: bool = False, force_tcp: bool = False
+ self, url: str, func: callable, force_tcp: bool = False
):
# Because it could be the case that a profile callback is started, store a threading event
# In the constructor, we have defined a wait event for Ask_Profile, this waits for this event to be set
# The event is set in self.set_profile
self.profile_event = threading.Event()
+
config, config_type, config_err = self.go_function(
- lib.GetConnectConfig, url, is_secure_internet, force_tcp
+ func, url, force_tcp
)
if config_err:
@@ -107,15 +109,20 @@ class EduVPN(object):
return config, config_type
+ def get_config_custom_server(
+ self, url: str, force_tcp: bool = False
+ ) -> Tuple[str, str]:
+ return self.get_config(url, lib.GetConfigCustomServer, force_tcp)
+
def get_config_institute_access(
self, url: str, force_tcp: bool = False
) -> Tuple[str, str]:
- return self.get_config(url, False, force_tcp)
+ return self.get_config(url, lib.GetConfigInstituteAccess, force_tcp)
def get_config_secure_internet(
self, url: str, force_tcp: bool = False
) -> Tuple[str, str]:
- return self.get_config(url, True, force_tcp)
+ return self.get_config(url, lib.GetConfigSecureInternet, force_tcp)
def set_connected(self) -> None:
connect_err = self.go_function(lib.SetConnected)
diff --git a/wrappers/python/tests.py b/wrappers/python/tests.py
index 7f17ef6..e58e67d 100644
--- a/wrappers/python/tests.py
+++ b/wrappers/python/tests.py
@@ -29,7 +29,7 @@ class ConfigTests(unittest.TestCase):
self.fail("No SERVER_URI environment variable given")
# This can throw an exception
- _eduvpn.get_config_institute_access(server_uri)
+ _eduvpn.get_config_custom_server(server_uri)
# Deregister
_eduvpn.deregister()