diff options
| -rw-r--r-- | cmd/cli/main.go | 196 | ||||
| -rw-r--r-- | exports/exports.go | 2 | ||||
| -rw-r--r-- | internal/discovery.go | 4 | ||||
| -rw-r--r-- | internal/fsm.go | 6 | ||||
| -rw-r--r-- | internal/oauth.go | 3 | ||||
| -rw-r--r-- | internal/server.go | 45 | ||||
| -rw-r--r-- | state.go | 27 | ||||
| -rw-r--r-- | state_test.go | 8 |
8 files changed, 254 insertions, 37 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go index c31716f..17348c9 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1,9 +1,13 @@ package main import ( + "encoding/json" "flag" "fmt" + "os" "os/exec" + "path" + "path/filepath" "strings" eduvpn "github.com/jwijenbergh/eduvpn-common" @@ -15,39 +19,197 @@ func openBrowser(urlString string) { exec.Command("xdg-open", urlString).Start() } -func logState(oldState string, newState string, data string) { - fmt.Printf("State: %s -> State: %s with data %s\n", oldState, newState, data) +// Taken from internal/server.go as it's an internal API for now +type ServerProfile struct { + ID string `json:"profile_id"` + DisplayName string `json:"display_name"` + VPNProtoList []string `json:"vpn_proto_list"` + DefaultGateway bool `json:"default_gateway"` +} + +type ServerProfileInfo struct { + Current string `json:"current_profile"` + Info struct { + ProfileList []ServerProfile `json:"profile_list"` + } `json:"info"` +} + +func sendProfile(state *eduvpn.VPNState, data string) { + fmt.Printf("Multiple VPN profiles found. Please select a profile by entering e.g. 1") + serverProfiles := &ServerProfileInfo{} + + jsonErr := json.Unmarshal([]byte(data), &serverProfiles) + if jsonErr != nil { + fmt.Println("\nFailed to get profile list", jsonErr) + return + } + + var profiles string + + for index, profile := range serverProfiles.Info.ProfileList { + profiles += fmt.Sprintf("\n%d - %s", index+1, profile.DisplayName) + } + + // Show the profiles + fmt.Println(profiles) + + var chosenProfile int + _, scanErr := fmt.Scanf("%d", &chosenProfile) + + if scanErr != nil || chosenProfile <= 0 || chosenProfile > len(serverProfiles.Info.ProfileList) { + fmt.Println("invalid profile chosen, please retry") + sendProfile(state, data) + return + } + + profile := serverProfiles.Info.ProfileList[chosenProfile-1] + fmt.Println("Sending profile ID", profile.ID) + profileErr := state.SetProfileID(profile.ID) + + if profileErr != nil { + fmt.Println("Failed setting profile with error", profileErr) + } +} + +func stateCallback(state *eduvpn.VPNState, oldState string, newState string, data string) { if newState == "OAuth_Started" { openBrowser(data) } + + if newState == "Ask_Profile" { + sendProfile(state, data) + } } -func main() { - urlArg := flag.String("url", "", "The url of the vpn") - flag.Parse() +func getConfig(url string, isInstitute bool) (string, error) { + if !strings.HasPrefix(url, "https://") { + url = "https://" + url + } + state := &eduvpn.VPNState{} + state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) { + stateCallback(state, old, new, data) + }, false) - urlString := *urlArg + defer state.Deregister() - if urlString != "" { - if !strings.HasPrefix(urlString, "https://") { - urlString = "https://" + urlString + if isInstitute { + return state.ConnectInstituteAccess(url) + } + return state.ConnectSecureInternet(url) +} + +type ServerDiscoEntry struct { + ServerType string `json:"server_type"` + BaseURL string `json:"base_url"` +} + +func getAllSecureInternetServers(serverList string) ([]string, error) { + var secureInternet []string + + discoEntries := []ServerDiscoEntry{} + + jsonErr := json.Unmarshal([]byte(serverList), &discoEntries) + + if jsonErr != nil { + return nil, jsonErr + } + + for _, entry := range discoEntries { + if entry.ServerType == "secure_internet" { + secureInternet = append(secureInternet, entry.BaseURL) } + } + + return secureInternet, nil +} + +func storeSecureInternetConfig(state *eduvpn.VPNState, url string, directory string) { + os.MkdirAll(directory, os.ModePerm) + + fmt.Println("Creating and storing cert for", url) + + config, configErr := getConfig(url, false) + + if configErr != nil { + fmt.Printf("Failed obtaining config for url %s with error %v\n", url, configErr) + } + + cleanURL := filepath.Base(url) + + writeErr := os.WriteFile(path.Join(directory, cleanURL), []byte(config), 0o644) + if writeErr != nil { + fmt.Printf("Failed writing config for url %s with error %v\n", url, writeErr) + } +} + +func getSecureInternetAll(homeURL string) { + state := &eduvpn.VPNState{} + + state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) { + stateCallback(state, old, new, data) + }, false) + + // Get the disco servers + servers, serversErr := state.GetDiscoServers() + + if serversErr != nil { + fmt.Println("Cannot obtain servers", serversErr) + return + } - state := &eduvpn.VPNState{} + secureInternetURLs, secureInternetErr := getAllSecureInternetServers(servers) - state.Register("org.eduvpn.app.linux", "configs", logState, true) - config, configErr := state.Connect(urlString) + if secureInternetErr != nil { + fmt.Println("Cannot parse secure internet servers", secureInternetErr) + return + } + + // Ensure that the directory exists + directory := "certs" + os.MkdirAll(directory, os.ModePerm) - if configErr != nil { - fmt.Printf("Config error %v", configErr) - return + // Obtain config for home server + storeSecureInternetConfig(state, homeURL, directory) + + for _, serverURL := range secureInternetURLs { + if !strings.Contains(serverURL, homeURL) { + storeSecureInternetConfig(state, serverURL, directory) } + } + + fmt.Println("Done storing all certs in directory:", directory) +} + +func printConfig(url string, isInstitute bool) { + config, configErr := getConfig(url, isInstitute) + + if configErr != nil { + fmt.Println("Error getting config", configErr) + return + } - fmt.Println(config) + fmt.Println("Obtained config", config) +} - state.Deregister() +func main() { + urlArg := flag.String("get-institute", "", "The url of an institute to connect to") + secureInternet := flag.String("get-secure", "", "Gets secure internet servers.") + secureInternetAll := flag.String("get-secure-all", "", "Gets certificates for all secure internet servers. It stores them in ./certs. Provide an URL for the home server e.g. nl.eduvpn.org.") + flag.Parse() + // Connect to a VPN by getting an Institute Access config + urlString := *urlArg + secureInternetString := *secureInternet + secureInternetAllString := *secureInternetAll + if urlString != "" { + printConfig(urlString, true) + return + } else if secureInternetString != "" { + printConfig(secureInternetString, false) + return + } else if secureInternetAllString != "" { + getSecureInternetAll(secureInternetAllString) return } diff --git a/exports/exports.go b/exports/exports.go index f21a354..576e980 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -103,7 +103,7 @@ func Connect(name *C.char, url *C.char) (*C.char, *C.char) { if stateErr != nil { return nil, C.CString(ErrorToString(stateErr)) } - config, configErr := state.Connect(C.GoString(url)) + config, configErr := state.ConnectInstituteAccess(C.GoString(url)) return C.CString(config), C.CString(ErrorToString(configErr)) } diff --git a/internal/discovery.go b/internal/discovery.go index 8c0acc7..59281bd 100644 --- a/internal/discovery.go +++ b/internal/discovery.go @@ -57,8 +57,8 @@ type ServersList struct { type Discovery struct { Organizations OrganizationList Servers ServersList - FSM *FSM - Logger *FileLogger + FSM *FSM + Logger *FileLogger } // Helper function that gets a disco json diff --git a/internal/fsm.go b/internal/fsm.go index 6997d92..1bcc479 100644 --- a/internal/fsm.go +++ b/internal/fsm.go @@ -4,8 +4,8 @@ import ( "fmt" "os" "os/exec" - "sort" "path" + "sort" ) type ( @@ -94,7 +94,7 @@ type FSM struct { Current FSMStateID // Info to be passed from the parent state - Name string + Name string StateCallback func(string, string, string) Logger *FileLogger Directory string @@ -107,7 +107,7 @@ func (fsm *FSM) Init(name string, callback func(string, string, string), logger NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}}, CHOSEN_SERVER: {{AUTHORIZED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}}, OAUTH_STARTED: {{AUTHORIZED, "User authorizes with browser"}, {CHOSEN_SERVER, "Cancel OAuth"}}, - AUTHORIZED: {{OAUTH_STARTED, "Re-authorize with OAuth"}, {REQUEST_CONFIG, "Client requests a config"}}, + AUTHORIZED: {{OAUTH_STARTED, "Re-authorize with OAuth"}, {REQUEST_CONFIG, "Client requests a config"}}, REQUEST_CONFIG: {{ASK_PROFILE, "Multiple profiles found"}, {HAS_CONFIG, "Success, only one profile"}}, ASK_PROFILE: {{HAS_CONFIG, "User chooses profile and success"}}, HAS_CONFIG: {{CONNECTED, "OS reports connected"}}, diff --git a/internal/oauth.go b/internal/oauth.go index 9d17777..98af5a4 100644 --- a/internal/oauth.go +++ b/internal/oauth.go @@ -326,8 +326,7 @@ func (oauth *OAuth) NeedsRelogin() bool { return true } -type OAuthCancelledCallbackError struct { -} +type OAuthCancelledCallbackError struct{} func (e *OAuthCancelledCallbackError) Error() string { return fmt.Sprintf("Client cancelled OAuth") diff --git a/internal/server.go b/internal/server.go index aa21a97..1d6f1e1 100644 --- a/internal/server.go +++ b/internal/server.go @@ -17,8 +17,9 @@ type Server struct { } type Servers struct { - List map[string]*Server `json:"list"` - Current string `json:"current"` + List map[string]*Server `json:"list"` + Current string `json:"current"` + SecureHome string `json:"secure_home"` } func (servers *Servers) GetCurrentServer() (*Server, error) { @@ -57,7 +58,10 @@ func (server *Server) EnsureTokens() error { return nil } -func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger) (*Server, error) { +func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, makeCurrent bool) (*Server, error) { + if url == "" { + return nil, errors.New("Emtpy URL to ensure Server") + } if servers.List == nil { servers.List = make(map[string]*Server) } @@ -73,10 +77,41 @@ func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger) ( return nil, serverInitErr } servers.List[url] = server - servers.Current = url + + if makeCurrent { + servers.Current = url + } + return server, nil +} + +func (servers *Servers) getSecureInternetHome() (*Server, error) { + server, exists := servers.List[servers.SecureHome] + + if !exists || server == nil { + return nil, errors.New("No secure internet home found") + } + return server, nil } +func (servers *Servers) EnsureSecureHome(server *Server) { + if servers.SecureHome == "" { + servers.SecureHome = server.BaseURL + } +} + +func (servers *Servers) CopySecureInternetOAuth(server *Server) error { + secureHome, secureHomeErr := servers.getSecureInternetHome() + + if secureHomeErr != nil { + return secureHomeErr + } + + // Forward token properties + server.OAuth = secureHome.OAuth + return nil +} + type ServerProfile struct { ID string `json:"profile_id"` DisplayName string `json:"display_name"` @@ -151,7 +186,7 @@ func (server *Server) getCurrentProfile() (*ServerProfile, error) { return &profile, nil } } - return nil, errors.New("no profile found for id") + return nil, errors.New(fmt.Sprintf("no profile found for id %s", profile_id)) } func (server *Server) getConfigWithProfile() (string, error) { @@ -2,6 +2,7 @@ package eduvpn import ( "errors" + "github.com/jwijenbergh/eduvpn-common/internal" ) @@ -86,16 +87,26 @@ func (state *VPNState) CancelOAuth() error { return nil } -func (state *VPNState) Connect(url string) (string, error) { +func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) { if state.FSM.InState(internal.DEREGISTERED) { return "", errors.New("app not registered") } // New server chosen, ensure the server is fresh - server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger) + server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger, true) if serverErr != nil { return "", serverErr } + + // When we connect to secure internet, copy over the tokens from the home server + if isSecureInternet { + // Ensure the secure home server + state.Servers.EnsureServer(state.Servers.SecureHome, &state.FSM, &state.Logger, false) + + // Copy the tokens + state.Servers.CopySecureInternetOAuth(server) + } + // Make sure we are in the chosen state if available state.FSM.GoTransition(internal.CHOSEN_SERVER) // Relogin with oauth @@ -113,6 +124,9 @@ func (state *VPNState) Connect(url string) (string, error) { state.FSM.GoTransition(internal.AUTHORIZED) } + // Set the home server if it is not set already + state.Servers.EnsureSecureHome(server) + state.FSM.GoTransition(internal.REQUEST_CONFIG) config, configErr := server.GetConfig() @@ -126,6 +140,14 @@ func (state *VPNState) Connect(url string) (string, error) { return config, nil } +func (state *VPNState) ConnectInstituteAccess(url string) (string, error) { + return state.connectWithOptions(url, false) +} + +func (state *VPNState) ConnectSecureInternet(url string) (string, error) { + return state.connectWithOptions(url, true) +} + func (state *VPNState) GetDiscoOrganizations() (string, error) { if state.FSM.InState(internal.DEREGISTERED) { return "", errors.New("app not registered") @@ -133,7 +155,6 @@ func (state *VPNState) GetDiscoOrganizations() (string, error) { return state.Discovery.GetOrganizationsList() } - func (state *VPNState) GetDiscoServers() (string, error) { if state.FSM.InState(internal.DEREGISTERED) { return "", errors.New("app not registered") diff --git a/state_test.go b/state_test.go index 5f37147..4320a6d 100644 --- a/state_test.go +++ b/state_test.go @@ -59,7 +59,7 @@ func Test_server(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, configErr := state.Connect(serverURI) + _, configErr := state.ConnectInstituteAccess(serverURI) if configErr != nil { t.Errorf("Connect error: %v", configErr) @@ -82,7 +82,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter } }, false) - _, configErr := state.Connect(serverURI) + _, configErr := state.ConnectInstituteAccess(serverURI) if !errors.As(configErr, expectedErr) { t.Errorf("error %T = %v, wantErr %T", configErr, configErr, expectedErr) @@ -130,7 +130,7 @@ func Test_token_expired(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, configErr := state.Connect(serverURI) + _, configErr := state.ConnectInstituteAccess(serverURI) if configErr != nil { t.Errorf("Connect error before expired: %v", configErr) @@ -174,7 +174,7 @@ func Test_token_invalid(t *testing.T) { stateCallback(t, old, new, data, state) }, false) - _, configErr := state.Connect(serverURI) + _, configErr := state.ConnectInstituteAccess(serverURI) if configErr != nil { t.Errorf("Connect error before invalid: %v", configErr) |
