summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeroen Wijenbergh <jeroenwijenbergh@protonmail.com>2022-04-19 15:02:45 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-19 15:02:45 +0200
commit723ecacc8528be0e96db42392f1781ddf5894bea (patch)
tree1debf1d6d0c50adb32939db3cc84e5130d1fb818
parent5f40a8d10a17182f744cb7ac11087d170dd49560 (diff)
Profiles: Implement basic functionality for sending a profile_id
-rw-r--r--exports/exports.go15
-rw-r--r--src/api.go8
-rw-r--r--src/fsm.go16
-rw-r--r--src/openvpn.go4
-rw-r--r--src/server.go55
-rw-r--r--src/server_test.go18
-rw-r--r--src/state.go7
-rw-r--r--src/wireguard.go4
-rw-r--r--wrappers/python/eduvpncommon/__init__.py1
-rw-r--r--wrappers/python/eduvpncommon/main.py7
-rw-r--r--wrappers/python/main.py6
11 files changed, 95 insertions, 46 deletions
diff --git a/exports/exports.go b/exports/exports.go
index e34721e..7d2b1ce 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -18,9 +18,9 @@ import "github.com/jwijenbergh/eduvpn-common/src"
var P_StateCallback C.PythonCB
-func StateCallback(old_state string, new_state string, data string) {
+func StateCallback(old_state string, new_state string, data string) string {
if P_StateCallback == nil {
- return
+ return ""
}
oldState_c := C.CString(old_state)
newState_c := C.CString(new_state)
@@ -29,6 +29,12 @@ func StateCallback(old_state string, new_state string, data string) {
C.free(unsafe.Pointer(oldState_c))
C.free(unsafe.Pointer(newState_c))
C.free(unsafe.Pointer(data_c))
+
+ // Get state data and reset
+ state := eduvpn.GetVPNState()
+ received_data := state.StateCallbackData
+ state.StateCallbackData = ""
+ return received_data
}
//export Register
@@ -75,6 +81,11 @@ func GetServersList() (*C.char, *C.char) {
return C.CString(servers), C.CString(ErrorToString(serversErr))
}
+//export SendData
+func SendData(data *C.char) {
+ eduvpn.GetVPNState().StateCallbackData = C.GoString(data)
+}
+
//export FreeString
func FreeString(addr *C.char) {
C.free(unsafe.Pointer(addr))
diff --git a/src/api.go b/src/api.go
index 1f058fc..a11c907 100644
--- a/src/api.go
+++ b/src/api.go
@@ -62,9 +62,7 @@ func (server *Server) APIInfo() error {
}
server.Profiles = structure
-
- // FIXME: Implement profile selection callback
- server.Profiles.Current = 0
+ server.ProfilesRaw = string(body)
return nil
}
@@ -75,7 +73,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
}
urlForm := url.Values{
- "profile_id": {"default"},
+ "profile_id": {profile_id},
"public_key": {pubkey},
}
header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
@@ -94,7 +92,7 @@ func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, erro
}
urlForm := url.Values{
- "profile_id": {"default"},
+ "profile_id": {profile_id},
}
header, connectBody, connectErr := server.apiAuthenticatedRetry(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
diff --git a/src/fsm.go b/src/fsm.go
index 0ed7a37..7a54fd8 100644
--- a/src/fsm.go
+++ b/src/fsm.go
@@ -23,6 +23,9 @@ const (
// Authenticated means the OAuth process has finished and the user is now authenticated with the server
AUTHENTICATED
+ // Ask profile means the go code is asking for a profile selection from the ui
+ ASK_PROFILE
+
// Connected means the user has been connected to the server
CONNECTED
)
@@ -37,6 +40,8 @@ func (s FSMStateID) String() string {
return "Chosen_Server"
case OAUTH_STARTED:
return "OAuth_Started"
+ case ASK_PROFILE:
+ return "Ask_Profile"
case AUTHENTICATED:
return "Authenticated"
case CONNECTED:
@@ -88,19 +93,21 @@ func (eduvpn *VPNState) writeGraph() {
f.WriteString(graph)
}
-func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool {
+func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) {
ok := eduvpn.HasTransition(newState)
+ received := ""
+
if ok {
oldState := eduvpn.FSM.Current
eduvpn.FSM.Current = newState
if eduvpn.Debug {
eduvpn.writeGraph()
}
- eduvpn.StateCallback(oldState.String(), newState.String(), data)
+ received = eduvpn.StateCallback(oldState.String(), newState.String(), data)
}
- return ok
+ return ok, received
}
func (eduvpn *VPNState) generateDotGraph() string {
@@ -141,7 +148,8 @@ func (eduvpn *VPNState) InitializeFSM() {
NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}},
CHOSEN_SERVER: {{AUTHENTICATED, "Found tokens in config"}, {OAUTH_STARTED, "No tokens found in config"}},
OAUTH_STARTED: {{AUTHENTICATED, "User authorizes with browser"}},
- AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}},
+ AUTHENTICATED: {{CONNECTED, "OS reports connected"}, {OAUTH_STARTED, "Re-authenticate with OAuth"}, {ASK_PROFILE, "Connect, multiple profiles detected"}},
+ ASK_PROFILE: {{CONNECTED, "OS reports connected"}},
CONNECTED: {{AUTHENTICATED, "OS reports disconnected"}},
},
Current: DEREGISTERED,
diff --git a/src/openvpn.go b/src/openvpn.go
index 2cab2c4..0cf8e36 100644
--- a/src/openvpn.go
+++ b/src/openvpn.go
@@ -1,7 +1,7 @@
package eduvpn
-func (server *Server) OpenVPNGetConfig() (string, error) {
- configOpenVPN, _, configErr := server.APIConnectOpenVPN("default")
+func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) {
+ configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id)
if configErr != nil {
return "", configErr
diff --git a/src/server.go b/src/server.go
index 8a600d4..20d9136 100644
--- a/src/server.go
+++ b/src/server.go
@@ -10,6 +10,7 @@ type Server struct {
Endpoints *ServerEndpoints `json:"endpoints"`
OAuth *OAuth `json:"oauth"`
Profiles *ServerProfileInfo `json:"profiles"`
+ ProfilesRaw string `json:"profiles_raw"`
}
type ServerProfile struct {
@@ -20,7 +21,7 @@ type ServerProfile struct {
}
type ServerProfileInfo struct {
- Current uint8 `json:"current_profile"`
+ Current string `json:"current_profile"`
Info struct {
ProfileList []ServerProfile `json:"profile_list"`
} `json:"info"`
@@ -84,17 +85,6 @@ func (server *Server) GetEndpoints() error {
return nil
}
-func (profiles *ServerProfileInfo) getCurrentProfile() (*ServerProfile, error) {
- if profiles.Info.ProfileList == nil {
- return nil, errors.New("No server profiles")
- }
-
- if (int)(profiles.Current) >= len(profiles.Info.ProfileList) {
- return nil, errors.New("Invalid profile")
- }
- return &profiles.Info.ProfileList[profiles.Current], nil
-}
-
func (profile *ServerProfile) supportsWireguard() bool {
for _, proto := range profile.VPNProtoList {
if proto == "wireguard" {
@@ -104,12 +94,31 @@ func (profile *ServerProfile) supportsWireguard() bool {
return false
}
-func (server *Server) GetCurrentProfile() (*ServerProfile, error) {
- if server.Profiles == nil {
- return nil, errors.New("No server profiles found")
+func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) {
+ for _, profile := range server.Profiles.Info.ProfileList {
+ if profile.ID == profile_id {
+ return &profile, nil
+ }
}
+ return nil, errors.New("no profile found for id")
+}
+
+func (server *Server) getConfigWithProfile(profile_id string) (string, error) {
+ profile, profileErr := server.getProfileForID(profile_id)
- return server.Profiles.getCurrentProfile()
+ if profileErr != nil {
+ return "", profileErr
+ }
+
+ if profile.supportsWireguard() {
+ return server.WireguardGetConfig(profile_id)
+ }
+ return server.OpenVPNGetConfig(profile_id)
+}
+
+func (server *Server) askForProfileID() (string, error) {
+ _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw)
+ return profile_id, nil
}
func (server *Server) GetConfig() (string, error) {
@@ -119,14 +128,16 @@ func (server *Server) GetConfig() (string, error) {
return "", infoErr
}
- profile, profileErr := server.GetCurrentProfile()
+ // Set the current profile if there is only one profile
+ if len(server.Profiles.Info.ProfileList) == 1 {
+ return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID)
+ }
+
+ profile_id, profileErr := server.askForProfileID()
if profileErr != nil {
- return "", profileErr
+ return "", nil
}
- if profile.supportsWireguard() {
- return server.WireguardGetConfig()
- }
- return server.OpenVPNGetConfig()
+ return server.getConfigWithProfile(profile_id)
}
diff --git a/src/server_test.go b/src/server_test.go
index f914583..5b6ec5a 100644
--- a/src/server_test.go
+++ b/src/server_test.go
@@ -35,10 +35,13 @@ func LoginOAuthSelenium(t *testing.T, url string) {
}
}
-func StateCallback(t *testing.T, oldState string, newState string, data string) {
+func StateCallback(t *testing.T, oldState string, newState string, data string) string {
if newState == "OAuth_Started" {
go LoginOAuthSelenium(t, data)
}
+
+ // We have no data to send back
+ return ""
}
func Test_server(t *testing.T) {
@@ -47,8 +50,8 @@ func Test_server(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) {
- StateCallback(t, old, new, data)
+ state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string {
+ return StateCallback(t, old, new, data)
}, false)
_, configErr := state.Connect("https://eduvpnserver")
@@ -65,7 +68,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) {
+ state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string {
if newState == "OAuth_Started" {
baseURL := "http://127.0.0.1:8000/callback"
url, err := HTTPConstructURL(baseURL, parameters)
@@ -73,7 +76,10 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect
t.Errorf("Error: Constructing url %s with parameters %s", baseURL, fmt.Sprint(parameters))
}
go http.Get(url)
+
}
+ // We have no data to send back
+ return ""
}, false)
_, configErr := state.Connect("https://eduvpnserver")
@@ -123,8 +129,8 @@ func Test_token_expired(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) {
- StateCallback(t, old, new, data)
+ state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string {
+ return StateCallback(t, old, new, data)
}, false)
diff --git a/src/state.go b/src/state.go
index 1eb8ebd..95bb75c 100644
--- a/src/state.go
+++ b/src/state.go
@@ -8,7 +8,8 @@ type VPNState struct {
// Info passed by the client
ConfigDirectory string `json:"-"`
Name string `json:"-"`
- StateCallback func(string, string, string) `json:"-"`
+ StateCallback func(string, string, string) string `json:"-"`
+ StateCallbackData string `json:"-"`
// The chosen server
Server *Server `json:"server"`
@@ -26,7 +27,7 @@ type VPNState struct {
Debug bool `json:"-"`
}
-func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error {
+func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error {
state.InitializeFSM()
if !state.InState(DEREGISTERED) {
return errors.New("app already registered")
@@ -71,7 +72,7 @@ func (state *VPNState) Deregister() error {
}
func (state *VPNState) Connect(url string) (string, error) {
- if state.Server == nil {
+ if state.Server == nil || state.Server.BaseURL != url {
state.Server = &Server{}
}
initializeErr := state.Server.Initialize(url)
diff --git a/src/wireguard.go b/src/wireguard.go
index b701b1d..3f7da40 100644
--- a/src/wireguard.go
+++ b/src/wireguard.go
@@ -26,7 +26,7 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string {
return interface_re.ReplaceAllString(config, to_replace)
}
-func (server *Server) WireguardGetConfig() (string, error) {
+func (server *Server) WireguardGetConfig(profile_id string) (string, error) {
wireguardKey, wireguardErr := wireguardGenerateKey()
if wireguardErr != nil {
@@ -34,7 +34,7 @@ func (server *Server) WireguardGetConfig() (string, error) {
}
wireguardPublicKey := wireguardKey.PublicKey().String()
- configWireguard, _, configErr := server.APIConnectWireguard("default", wireguardPublicKey)
+ configWireguard, _, configErr := server.APIConnectWireguard(profile_id, wireguardPublicKey)
if configErr != nil {
return "", configErr
diff --git a/wrappers/python/eduvpncommon/__init__.py b/wrappers/python/eduvpncommon/__init__.py
index 056ce18..50af8eb 100644
--- a/wrappers/python/eduvpncommon/__init__.py
+++ b/wrappers/python/eduvpncommon/__init__.py
@@ -38,6 +38,7 @@ lib.Deregister.argtypes, lib.Deregister.restype = [], None
lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p
lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [], DataError
lib.GetServersList.argtypes, lib.GetServersList.restype = [], DataError
+lib.SendData.argtypes, lib.SendData.restype = [c_char_p], None
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
diff --git a/wrappers/python/eduvpncommon/main.py b/wrappers/python/eduvpncommon/main.py
index 53b9831..8d2de9c 100644
--- a/wrappers/python/eduvpncommon/main.py
+++ b/wrappers/python/eduvpncommon/main.py
@@ -47,6 +47,10 @@ def register_callback(eduvpn):
)
+def SendData(data):
+ lib.SendData(data.encode("utf-8"))
+
+
class EduVPN(object):
def __init__(self, name, config_directory):
self.event_handler = EventHandler()
@@ -75,6 +79,9 @@ class EduVPN(object):
def callback(self, old_state, new_state, data):
self.event.run(old_state, new_state, data)
+ def send_data(self, data):
+ return SendData(data)
+
class EventHandler(object):
def __init__(self):
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index d6a568e..bc29239 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -11,6 +11,12 @@ def oauth_initialized(url):
webbrowser.open(url)
+@_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter)
+def ask_profile(profiles):
+ print("ASK PROFILE CB", profiles)
+ _eduvpn.send_data("prefer-openvpn")
+
+
success = _eduvpn.register(debug=True)
if not success: