summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJeroen Wijenbergh <jeroenwijenbergh@protonmail.com>2022-04-19 15:02:45 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-19 15:02:45 +0200
commit723ecacc8528be0e96db42392f1781ddf5894bea (patch)
tree1debf1d6d0c50adb32939db3cc84e5130d1fb818 /src
parent5f40a8d10a17182f744cb7ac11087d170dd49560 (diff)
Profiles: Implement basic functionality for sending a profile_id
Diffstat (limited to 'src')
-rw-r--r--src/api.go8
-rw-r--r--src/fsm.go16
-rw-r--r--src/openvpn.go4
-rw-r--r--src/server.go55
-rw-r--r--src/server_test.go18
-rw-r--r--src/state.go7
-rw-r--r--src/wireguard.go4
7 files changed, 68 insertions, 44 deletions
diff --git a/src/api.go b/src/api.go
index 1f058fc..a11c907 100644
--- a/src/api.go
+++ b/src/api.go
@@ -62,9 +62,7 @@ func (server *Server) APIInfo() error {
}
server.Profiles = structure
-
- // FIXME: Implement profile selection callback
- server.Profiles.Current = 0
+ server.ProfilesRaw = string(body)
return nil
}
@@ -75,7 +73,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
}
urlForm := url.Values{
- "profile_id": {"default"},
+ "profile_id": {profile_id},
"public_key": {pubkey},
}
header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
@@ -94,7 +92,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
}
urlForm := url.Values{
- "profile_id": {"default"},
+ "profile_id": {profile_id},
}
header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
diff --git a/src/fsm.go b/src/fsm.go
index 0ed7a37..7a54fd8 100644
--- a/src/fsm.go
+++ b/src/fsm.go
@@ -23,6 +23,9 @@ const (
// Authenticated means the OAuth process has finished and the user is now authenticated with the server
AUTHENTICATED
+ // Ask profile means the go code is asking for a profile selection from the ui
+ ASK_PROFILE
+
// Connected means the user has been connected to the server
CONNECTED
)
@@ -37,6 +40,8 @@ func (s FSMStateID) String() string {
return "Chosen_Server"
case OAUTH_STARTED:
return "OAuth_Started"
+ case ASK_PROFILE:
+ return "Ask_Profile"
case AUTHENTICATED:
return "Authenticated"
case CONNECTED:
@@ -88,19 +93,21 @@ func (eduvpn *VPNState) writeGraph() {
f.WriteString(graph)
}
-func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool {
+func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) {
ok := eduvpn.HasTransition(newState)
+ received := ""
+
if ok {
oldState := eduvpn.FSM.Current
eduvpn.FSM.Current = newState
if eduvpn.Debug {
eduvpn.writeGraph()
}
- eduvpn.StateCallback(oldState.String(), newState.String(), data)
+ received = eduvpn.StateCallback(oldState.String(), newState.String(), data)
}
- return ok
+ return ok, received
}
func (eduvpn *VPNState) generateDotGraph() string {
@@ -141,7 +148,8 @@ func (eduvpn *VPNState) InitializeFSM() {
NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}},
CHOSEN_SERVER: {{AUTHENTICATED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}},
OAUTH_STARTED: {{AUTHENTICATED, "User authorizes with browser"}},
- AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}},
+ AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}, {ASK_PROFILE, "Connect, multiple profiles detected"}},
+ ASK_PROFILE: {{CONNECTED, "OS reports connected"}},
CONNECTED: {{AUTHENTICATED, "OS reports disconnected"}},
},
Current: DEREGISTERED,
diff --git a/src/openvpn.go b/src/openvpn.go
index 2cab2c4..0cf8e36 100644
--- a/src/openvpn.go
+++ b/src/openvpn.go
@@ -1,7 +1,7 @@
package eduvpn
-func (server *Server) OpenVPNGetConfig() (string, error) {
- configOpenVPN, _, configErr := server.APIConnectOpenVPN("default")
+func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) {
+ configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id)
if configErr != nil {
return "", configErr
diff --git a/src/server.go b/src/server.go
index 8a600d4..20d9136 100644
--- a/src/server.go
+++ b/src/server.go
@@ -10,6 +10,7 @@ type Server struct {
Endpoints *ServerEndpoints `json:"endpoints"`
OAuth *OAuth `json:"oauth"`
Profiles *ServerProfileInfo `json:"profiles"`
+ ProfilesRaw string `json:"profiles_raw"`
}
type ServerProfile struct {
@@ -20,7 +21,7 @@ type ServerProfile struct {
}
type ServerProfileInfo struct {
- Current uint8 `json:"current_profile"`
+ Current string `json:"current_profile"`
Info struct {
ProfileList []ServerProfile `json:"profile_list"`
} `json:"info"`
@@ -84,17 +85,6 @@ func (server *Server) GetEndpoints() error {
return nil
}
-func (profiles *ServerProfileInfo) getCurrentProfile() (*ServerProfile, error) {
- if profiles.Info.ProfileList == nil {
- return nil, errors.New("No server profiles")
- }
-
- if (int)(profiles.Current) >= len(profiles.Info.ProfileList) {
- return nil, errors.New("Invalid profile")
- }
- return &profiles.Info.ProfileList[profiles.Current], nil
-}
-
func (profile *ServerProfile) supportsWireguard() bool {
for _, proto := range profile.VPNProtoList {
if proto == "wireguard" {
@@ -104,12 +94,31 @@ func (profile *ServerProfile) supportsWireguard() bool {
return false
}
-func (server *Server) GetCurrentProfile() (*ServerProfile, error) {
- if server.Profiles == nil {
- return nil, errors.New("No server profiles found")
+func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) {
+ for _, profile := range server.Profiles.Info.ProfileList {
+ if profile.ID == profile_id {
+ return &profile, nil
+ }
}
+ return nil, errors.New("no profile found for id")
+}
+
+func (server *Server) getConfigWithProfile(profile_id string) (string, error) {
+ profile, profileErr := server.getProfileForID(profile_id)
- return server.Profiles.getCurrentProfile()
+ if profileErr != nil {
+ return "", profileErr
+ }
+
+ if profile.supportsWireguard() {
+ return server.WireguardGetConfig(profile_id)
+ }
+ return server.OpenVPNGetConfig(profile_id)
+}
+
+func (server *Server) askForProfileID() (string, error) {
+ _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw)
+ return profile_id, nil
}
func (server *Server) GetConfig() (string, error) {
@@ -119,14 +128,16 @@ func (server *Server) GetConfig() (string, error) {
return "", infoErr
}
- profile, profileErr := server.GetCurrentProfile()
+ // Set the current profile if there is only one profile
+ if len(server.Profiles.Info.ProfileList) == 1 {
+ return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID)
+ }
+
+ profile_id, profileErr := server.askForProfileID()
if profileErr != nil {
- return "", profileErr
+ return "", nil
}
- if profile.supportsWireguard() {
- return server.WireguardGetConfig()
- }
- return server.OpenVPNGetConfig()
+ return server.getConfigWithProfile(profile_id)
}
diff --git a/src/server_test.go b/src/server_test.go
index f914583..5b6ec5a 100644
--- a/src/server_test.go
+++ b/src/server_test.go
@@ -35,10 +35,13 @@ func LoginOAuthSelenium(t *testing.T, url string) {
}
}
-func StateCallback(t *testing.T, oldState string, newState string, data string) {
+func StateCallback(t *testing.T, oldState string, newState string, data string) string {
if newState == "OAuth_Started" {
go LoginOAuthSelenium(t, data)
}
+
+ // We have no data to send back
+ return ""
}
func Test_server(t *testing.T) {
@@ -47,8 +50,8 @@ func Test_server(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) {
- StateCallback(t, old, new, data)
+ state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string {
+ return StateCallback(t, old, new, data)
}, false)
_, configErr := state.Connect("https://eduvpnserver")
@@ -65,7 +68,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) {
+ state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string {
if newState == "OAuth_Started" {
baseURL := "http://127.0.0.1:8000/callback"
url, err := HTTPConstructURL(baseURL, parameters)
@@ -73,7 +76,10 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect
t.Errorf("Error: Constructing url %s with parameters %s", baseURL, fmt.Sprint(parameters))
}
go http.Get(url)
+
}
+ // We have no data to send back
+ return ""
}, false)
_, configErr := state.Connect("https://eduvpnserver")
@@ -123,8 +129,8 @@ func Test_token_expired(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) {
- StateCallback(t, old, new, data)
+ state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string {
+ return StateCallback(t, old, new, data)
}, false)
diff --git a/src/state.go b/src/state.go
index 1eb8ebd..95bb75c 100644
--- a/src/state.go
+++ b/src/state.go
@@ -8,7 +8,8 @@ type VPNState struct {
// Info passed by the client
ConfigDirectory string `json:"-"`
Name string `json:"-"`
- StateCallback func(string, string, string) `json:"-"`
+ StateCallback func(string, string, string) string `json:"-"`
+ StateCallbackData string `json:"-"`
// The chosen server
Server *Server `json:"server"`
@@ -26,7 +27,7 @@ type VPNState struct {
Debug bool `json:"-"`
}
-func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error {
+func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error {
state.InitializeFSM()
if !state.InState(DEREGISTERED) {
return errors.New("app already registered")
@@ -71,7 +72,7 @@ func (state *VPNState) Deregister() error {
}
func (state *VPNState) Connect(url string) (string, error) {
- if state.Server == nil {
+ if state.Server == nil || state.Server.BaseURL != url {
state.Server = &Server{}
}
initializeErr := state.Server.Initialize(url)
diff --git a/src/wireguard.go b/src/wireguard.go
index b701b1d..3f7da40 100644
--- a/src/wireguard.go
+++ b/src/wireguard.go
@@ -26,7 +26,7 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string {
return interface_re.ReplaceAllString(config, to_replace)
}
-func (server *Server) WireguardGetConfig() (string, error) {
+func (server *Server) WireguardGetConfig(profile_id string) (string, error) {
wireguardKey, wireguardErr := wireguardGenerateKey()
if wireguardErr != nil {
@@ -34,7 +34,7 @@ func (server *Server) WireguardGetConfig() (string, error) {
}
wireguardPublicKey := wireguardKey.PublicKey().String()
- configWireguard, _, configErr := server.APIConnectWireguard("default", wireguardPublicKey)
+ configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey)
if configErr != nil {
return "", configErr