summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeroen Wijenbergh <jeroenwijenbergh@protonmail.com>2022-04-29 15:08:32 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-29 15:08:32 +0200
commit0e1f9826f2aea1a059529f9c3d1c921d7d4ac3d4 (patch)
tree2d26bd6dbd33abde910bff00078f520dad890a4d
parent6c7a1c7a9245cf457a86fd15bdc14bc93b55d508 (diff)
Secure Internet: Basic implementation and add support to cli
-rw-r--r--cmd/cli/main.go196
-rw-r--r--exports/exports.go2
-rw-r--r--internal/discovery.go4
-rw-r--r--internal/fsm.go6
-rw-r--r--internal/oauth.go3
-rw-r--r--internal/server.go45
-rw-r--r--state.go27
-rw-r--r--state_test.go8
8 files changed, 254 insertions, 37 deletions
diff --git a/cmd/cli/main.go b/cmd/cli/main.go
index c31716f..17348c9 100644
--- a/cmd/cli/main.go
+++ b/cmd/cli/main.go
@@ -1,9 +1,13 @@
package main
import (
+ "encoding/json"
"flag"
"fmt"
+ "os"
"os/exec"
+ "path"
+ "path/filepath"
"strings"
eduvpn "github.com/jwijenbergh/eduvpn-common"
@@ -15,39 +19,197 @@ func openBrowser(urlString string) {
exec.Command("xdg-open", urlString).Start()
}
-func logState(oldState string, newState string, data string) {
- fmt.Printf("State: %s -> State: %s with data %s\n", oldState, newState, data)
+// Taken from internal/server.go as it's an internal API for now
+type ServerProfile struct {
+ ID string `json:"profile_id"`
+ DisplayName string `json:"display_name"`
+ VPNProtoList []string `json:"vpn_proto_list"`
+ DefaultGateway bool `json:"default_gateway"`
+}
+
+type ServerProfileInfo struct {
+ Current string `json:"current_profile"`
+ Info struct {
+ ProfileList []ServerProfile `json:"profile_list"`
+ } `json:"info"`
+}
+
+func sendProfile(state *eduvpn.VPNState, data string) {
+ fmt.Printf("Multiple VPN profiles found. Please select a profile by entering e.g. 1")
+ serverProfiles := &ServerProfileInfo{}
+
+ jsonErr := json.Unmarshal([]byte(data), &serverProfiles)
+ if jsonErr != nil {
+ fmt.Println("\nFailed to get profile list", jsonErr)
+ return
+ }
+
+ var profiles string
+
+ for index, profile := range serverProfiles.Info.ProfileList {
+ profiles += fmt.Sprintf("\n%d - %s", index+1, profile.DisplayName)
+ }
+
+ // Show the profiles
+ fmt.Println(profiles)
+
+ var chosenProfile int
+ _, scanErr := fmt.Scanf("%d", &chosenProfile)
+
+ if scanErr != nil || chosenProfile <= 0 || chosenProfile > len(serverProfiles.Info.ProfileList) {
+ fmt.Println("invalid profile chosen, please retry")
+ sendProfile(state, data)
+ return
+ }
+
+ profile := serverProfiles.Info.ProfileList[chosenProfile-1]
+ fmt.Println("Sending profile ID", profile.ID)
+ profileErr := state.SetProfileID(profile.ID)
+
+ if profileErr != nil {
+ fmt.Println("Failed setting profile with error", profileErr)
+ }
+}
+
+func stateCallback(state *eduvpn.VPNState, oldState string, newState string, data string) {
if newState == "OAuth_Started" {
openBrowser(data)
}
+
+ if newState == "Ask_Profile" {
+ sendProfile(state, data)
+ }
}
-func main() {
- urlArg := flag.String("url", "", "The url of the vpn")
- flag.Parse()
+func getConfig(url string, isInstitute bool) (string, error) {
+ if !strings.HasPrefix(url, "https://") {
+ url = "https://" + url
+ }
+ state := &eduvpn.VPNState{}
+ state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) {
+ stateCallback(state, old, new, data)
+ }, false)
- urlString := *urlArg
+ defer state.Deregister()
- if urlString != "" {
- if !strings.HasPrefix(urlString, "https://") {
- urlString = "https://" + urlString
+ if isInstitute {
+ return state.ConnectInstituteAccess(url)
+ }
+ return state.ConnectSecureInternet(url)
+}
+
+type ServerDiscoEntry struct {
+ ServerType string `json:"server_type"`
+ BaseURL string `json:"base_url"`
+}
+
+func getAllSecureInternetServers(serverList string) ([]string, error) {
+ var secureInternet []string
+
+ discoEntries := []ServerDiscoEntry{}
+
+ jsonErr := json.Unmarshal([]byte(serverList), &discoEntries)
+
+ if jsonErr != nil {
+ return nil, jsonErr
+ }
+
+ for _, entry := range discoEntries {
+ if entry.ServerType == "secure_internet" {
+ secureInternet = append(secureInternet, entry.BaseURL)
}
+ }
+
+ return secureInternet, nil
+}
+
+func storeSecureInternetConfig(state *eduvpn.VPNState, url string, directory string) {
+ os.MkdirAll(directory, os.ModePerm)
+
+ fmt.Println("Creating and storing cert for", url)
+
+ config, configErr := getConfig(url, false)
+
+ if configErr != nil {
+ fmt.Printf("Failed obtaining config for url %s with error %v\n", url, configErr)
+ }
+
+ cleanURL := filepath.Base(url)
+
+ writeErr := os.WriteFile(path.Join(directory, cleanURL), []byte(config), 0o644)
+ if writeErr != nil {
+ fmt.Printf("Failed writing config for url %s with error %v\n", url, writeErr)
+ }
+}
+
+func getSecureInternetAll(homeURL string) {
+ state := &eduvpn.VPNState{}
+
+ state.Register("org.eduvpn.app.linux", "configs", func(old string, new string, data string) {
+ stateCallback(state, old, new, data)
+ }, false)
+
+ // Get the disco servers
+ servers, serversErr := state.GetDiscoServers()
+
+ if serversErr != nil {
+ fmt.Println("Cannot obtain servers", serversErr)
+ return
+ }
- state := &eduvpn.VPNState{}
+ secureInternetURLs, secureInternetErr := getAllSecureInternetServers(servers)
- state.Register("org.eduvpn.app.linux", "configs", logState, true)
- config, configErr := state.Connect(urlString)
+ if secureInternetErr != nil {
+ fmt.Println("Cannot parse secure internet servers", secureInternetErr)
+ return
+ }
+
+ // Ensure that the directory exists
+ directory := "certs"
+ os.MkdirAll(directory, os.ModePerm)
- if configErr != nil {
- fmt.Printf("Config error %v", configErr)
- return
+ // Obtain config for home server
+ storeSecureInternetConfig(state, homeURL, directory)
+
+ for _, serverURL := range secureInternetURLs {
+ if !strings.Contains(serverURL, homeURL) {
+ storeSecureInternetConfig(state, serverURL, directory)
}
+ }
+
+ fmt.Println("Done storing all certs in directory:", directory)
+}
+
+func printConfig(url string, isInstitute bool) {
+ config, configErr := getConfig(url, isInstitute)
+
+ if configErr != nil {
+ fmt.Println("Error getting config", configErr)
+ return
+ }
- fmt.Println(config)
+ fmt.Println("Obtained config", config)
+}
- state.Deregister()
+func main() {
+ urlArg := flag.String("get-institute", "", "The url of an institute to connect to")
+ secureInternet := flag.String("get-secure", "", "Gets secure internet servers.")
+ secureInternetAll := flag.String("get-secure-all", "", "Gets certificates for all secure internet servers. It stores them in ./certs. Provide an URL for the home server e.g. nl.eduvpn.org.")
+ flag.Parse()
+ // Connect to a VPN by getting an Institute Access config
+ urlString := *urlArg
+ secureInternetString := *secureInternet
+ secureInternetAllString := *secureInternetAll
+ if urlString != "" {
+ printConfig(urlString, true)
+ return
+ } else if secureInternetString != "" {
+ printConfig(secureInternetString, false)
+ return
+ } else if secureInternetAllString != "" {
+ getSecureInternetAll(secureInternetAllString)
return
}
diff --git a/exports/exports.go b/exports/exports.go
index f21a354..576e980 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -103,7 +103,7 @@ func Connect(name *C.char, url *C.char) (*C.char, *C.char) {
if stateErr != nil {
return nil, C.CString(ErrorToString(stateErr))
}
- config, configErr := state.Connect(C.GoString(url))
+ config, configErr := state.ConnectInstituteAccess(C.GoString(url))
return C.CString(config), C.CString(ErrorToString(configErr))
}
diff --git a/internal/discovery.go b/internal/discovery.go
index 8c0acc7..59281bd 100644
--- a/internal/discovery.go
+++ b/internal/discovery.go
@@ -57,8 +57,8 @@ type ServersList struct {
type Discovery struct {
Organizations OrganizationList
Servers ServersList
- FSM *FSM
- Logger *FileLogger
+ FSM *FSM
+ Logger *FileLogger
}
// Helper function that gets a disco json
diff --git a/internal/fsm.go b/internal/fsm.go
index 6997d92..1bcc479 100644
--- a/internal/fsm.go
+++ b/internal/fsm.go
@@ -4,8 +4,8 @@ import (
"fmt"
"os"
"os/exec"
- "sort"
"path"
+ "sort"
)
type (
@@ -94,7 +94,7 @@ type FSM struct {
Current FSMStateID
// Info to be passed from the parent state
- Name string
+ Name string
StateCallback func(string, string, string)
Logger *FileLogger
Directory string
@@ -107,7 +107,7 @@ func (fsm *FSM) Init(name string, callback func(string, string, string), logger
NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}},
CHOSEN_SERVER: {{AUTHORIZED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}},
OAUTH_STARTED: {{AUTHORIZED, "User authorizes with browser"}, {CHOSEN_SERVER, "Cancel OAuth"}},
- AUTHORIZED: {{OAUTH_STARTED, "Re-authorize with OAuth"}, {REQUEST_CONFIG, "Client requests a config"}},
+ AUTHORIZED: {{OAUTH_STARTED, "Re-authorize with OAuth"}, {REQUEST_CONFIG, "Client requests a config"}},
REQUEST_CONFIG: {{ASK_PROFILE, "Multiple profiles found"}, {HAS_CONFIG, "Success, only one profile"}},
ASK_PROFILE: {{HAS_CONFIG, "User chooses profile and success"}},
HAS_CONFIG: {{CONNECTED, "OS reports connected"}},
diff --git a/internal/oauth.go b/internal/oauth.go
index 9d17777..98af5a4 100644
--- a/internal/oauth.go
+++ b/internal/oauth.go
@@ -326,8 +326,7 @@ func (oauth *OAuth) NeedsRelogin() bool {
return true
}
-type OAuthCancelledCallbackError struct {
-}
+type OAuthCancelledCallbackError struct{}
func (e *OAuthCancelledCallbackError) Error() string {
return fmt.Sprintf("Client cancelled OAuth")
diff --git a/internal/server.go b/internal/server.go
index aa21a97..1d6f1e1 100644
--- a/internal/server.go
+++ b/internal/server.go
@@ -17,8 +17,9 @@ type Server struct {
}
type Servers struct {
- List map[string]*Server `json:"list"`
- Current string `json:"current"`
+ List map[string]*Server `json:"list"`
+ Current string `json:"current"`
+ SecureHome string `json:"secure_home"`
}
func (servers *Servers) GetCurrentServer() (*Server, error) {
@@ -57,7 +58,10 @@ func (server *Server) EnsureTokens() error {
return nil
}
-func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger) (*Server, error) {
+func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger, makeCurrent bool) (*Server, error) {
+ if url == "" {
+ return nil, errors.New("Emtpy URL to ensure Server")
+ }
if servers.List == nil {
servers.List = make(map[string]*Server)
}
@@ -73,10 +77,41 @@ func (servers *Servers) EnsureServer(url string, fsm *FSM, logger *FileLogger) (
return nil, serverInitErr
}
servers.List[url] = server
- servers.Current = url
+
+ if makeCurrent {
+ servers.Current = url
+ }
+ return server, nil
+}
+
+func (servers *Servers) getSecureInternetHome() (*Server, error) {
+ server, exists := servers.List[servers.SecureHome]
+
+ if !exists || server == nil {
+ return nil, errors.New("No secure internet home found")
+ }
+
return server, nil
}
+func (servers *Servers) EnsureSecureHome(server *Server) {
+ if servers.SecureHome == "" {
+ servers.SecureHome = server.BaseURL
+ }
+}
+
+func (servers *Servers) CopySecureInternetOAuth(server *Server) error {
+ secureHome, secureHomeErr := servers.getSecureInternetHome()
+
+ if secureHomeErr != nil {
+ return secureHomeErr
+ }
+
+ // Forward token properties
+ server.OAuth = secureHome.OAuth
+ return nil
+}
+
type ServerProfile struct {
ID string `json:"profile_id"`
DisplayName string `json:"display_name"`
@@ -151,7 +186,7 @@ func (server *Server) getCurrentProfile() (*ServerProfile, error) {
return &profile, nil
}
}
- return nil, errors.New("no profile found for id")
+ return nil, errors.New(fmt.Sprintf("no profile found for id %s", profile_id))
}
func (server *Server) getConfigWithProfile() (string, error) {
diff --git a/state.go b/state.go
index 3ca0a4b..c69cf37 100644
--- a/state.go
+++ b/state.go
@@ -2,6 +2,7 @@ package eduvpn
import (
"errors"
+
"github.com/jwijenbergh/eduvpn-common/internal"
)
@@ -86,16 +87,26 @@ func (state *VPNState) CancelOAuth() error {
return nil
}
-func (state *VPNState) Connect(url string) (string, error) {
+func (state *VPNState) connectWithOptions(url string, isSecureInternet bool) (string, error) {
if state.FSM.InState(internal.DEREGISTERED) {
return "", errors.New("app not registered")
}
// New server chosen, ensure the server is fresh
- server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger)
+ server, serverErr := state.Servers.EnsureServer(url, &state.FSM, &state.Logger, true)
if serverErr != nil {
return "", serverErr
}
+
+ // When we connect to secure internet, copy over the tokens from the home server
+ if isSecureInternet {
+ // Ensure the secure home server
+ state.Servers.EnsureServer(state.Servers.SecureHome, &state.FSM, &state.Logger, false)
+
+ // Copy the tokens
+ state.Servers.CopySecureInternetOAuth(server)
+ }
+
// Make sure we are in the chosen state if available
state.FSM.GoTransition(internal.CHOSEN_SERVER)
// Relogin with oauth
@@ -113,6 +124,9 @@ func (state *VPNState) Connect(url string) (string, error) {
state.FSM.GoTransition(internal.AUTHORIZED)
}
+ // Set the home server if it is not set already
+ state.Servers.EnsureSecureHome(server)
+
state.FSM.GoTransition(internal.REQUEST_CONFIG)
config, configErr := server.GetConfig()
@@ -126,6 +140,14 @@ func (state *VPNState) Connect(url string) (string, error) {
return config, nil
}
+func (state *VPNState) ConnectInstituteAccess(url string) (string, error) {
+ return state.connectWithOptions(url, false)
+}
+
+func (state *VPNState) ConnectSecureInternet(url string) (string, error) {
+ return state.connectWithOptions(url, true)
+}
+
func (state *VPNState) GetDiscoOrganizations() (string, error) {
if state.FSM.InState(internal.DEREGISTERED) {
return "", errors.New("app not registered")
@@ -133,7 +155,6 @@ func (state *VPNState) GetDiscoOrganizations() (string, error) {
return state.Discovery.GetOrganizationsList()
}
-
func (state *VPNState) GetDiscoServers() (string, error) {
if state.FSM.InState(internal.DEREGISTERED) {
return "", errors.New("app not registered")
diff --git a/state_test.go b/state_test.go
index 5f37147..4320a6d 100644
--- a/state_test.go
+++ b/state_test.go
@@ -59,7 +59,7 @@ func Test_server(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, configErr := state.Connect(serverURI)
+ _, configErr := state.ConnectInstituteAccess(serverURI)
if configErr != nil {
t.Errorf("Connect error: %v", configErr)
@@ -82,7 +82,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter
}
}, false)
- _, configErr := state.Connect(serverURI)
+ _, configErr := state.ConnectInstituteAccess(serverURI)
if !errors.As(configErr, expectedErr) {
t.Errorf("error %T = %v, wantErr %T", configErr, configErr, expectedErr)
@@ -130,7 +130,7 @@ func Test_token_expired(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, configErr := state.Connect(serverURI)
+ _, configErr := state.ConnectInstituteAccess(serverURI)
if configErr != nil {
t.Errorf("Connect error before expired: %v", configErr)
@@ -174,7 +174,7 @@ func Test_token_invalid(t *testing.T) {
stateCallback(t, old, new, data, state)
}, false)
- _, configErr := state.Connect(serverURI)
+ _, configErr := state.ConnectInstituteAccess(serverURI)
if configErr != nil {
t.Errorf("Connect error before invalid: %v", configErr)