summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-20 14:28:08 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-04-20 14:28:08 +0200
commit1c54936626a4a30d0c6f69576a06ba3661f39dc6 (patch)
tree426c76b4c55cf9a9efbc7bd1aa957baec57b2892
parent77c9f266553cbadfd5fb150a26c2162b705f151e (diff)
Profiles: Implement SetProfileID instead of getting generic data
-rw-r--r--exports/exports.go31
-rw-r--r--src/fsm.go8
-rw-r--r--src/openvpn.go3
-rw-r--r--src/server.go26
-rw-r--r--src/server_test.go17
-rw-r--r--src/state.go10
-rw-r--r--src/wireguard.go3
-rw-r--r--wrappers/python/eduvpncommon/__init__.py2
-rw-r--r--wrappers/python/eduvpncommon/main.py8
-rw-r--r--wrappers/python/main.py2
10 files changed, 58 insertions, 52 deletions
diff --git a/exports/exports.go b/exports/exports.go
index 7d2b1ce..af71eba 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -18,9 +18,9 @@ import "github.com/jwijenbergh/eduvpn-common/src"
var P_StateCallback C.PythonCB
-func StateCallback(old_state string, new_state string, data string) string {
+func StateCallback(old_state string, new_state string, data string) {
if P_StateCallback == nil {
- return ""
+ return
}
oldState_c := C.CString(old_state)
newState_c := C.CString(new_state)
@@ -29,12 +29,6 @@ func StateCallback(old_state string, new_state string, data string) string {
C.free(unsafe.Pointer(oldState_c))
C.free(unsafe.Pointer(newState_c))
C.free(unsafe.Pointer(data_c))
-
- // Get state data and reset
- state := eduvpn.GetVPNState()
- received_data := state.StateCallbackData
- state.StateCallbackData = ""
- return received_data
}
//export Register
@@ -81,9 +75,24 @@ func GetServersList() (*C.char, *C.char) {
return C.CString(servers), C.CString(ErrorToString(serversErr))
}
-//export SendData
-func SendData(data *C.char) {
- eduvpn.GetVPNState().StateCallbackData = C.GoString(data)
+//export SetProfileID
+func SetProfileID(data *C.char) {
+ state := eduvpn.GetVPNState()
+
+ // No server
+ if state.Server == nil {
+ return
+ }
+
+ // No profiles for server
+ if state.Server.Profiles == nil {
+ return
+ }
+
+ // Set current profile to id
+ profile_id := C.GoString(data)
+
+ state.Server.Profiles.Current = profile_id
}
//export FreeString
diff --git a/src/fsm.go b/src/fsm.go
index a4e7be5..7eb1ca1 100644
--- a/src/fsm.go
+++ b/src/fsm.go
@@ -119,21 +119,19 @@ func (eduvpn *VPNState) writeGraph() {
f.WriteString(graph)
}
-func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) {
+func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool {
ok := eduvpn.HasTransition(newState)
- received := ""
-
if ok {
oldState := eduvpn.FSM.Current
eduvpn.FSM.Current = newState
if eduvpn.Debug {
eduvpn.writeGraph()
}
- received = eduvpn.StateCallback(oldState.String(), newState.String(), data)
+ eduvpn.StateCallback(oldState.String(), newState.String(), data)
}
- return ok, received
+ return ok
}
func (eduvpn *VPNState) generateDotGraph() string {
diff --git a/src/openvpn.go b/src/openvpn.go
index 0cf8e36..95e1328 100644
--- a/src/openvpn.go
+++ b/src/openvpn.go
@@ -1,6 +1,7 @@
package eduvpn
-func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) {
+func (server *Server) OpenVPNGetConfig() (string, error) {
+ profile_id := server.Profiles.Current
configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id)
if configErr != nil {
diff --git a/src/server.go b/src/server.go
index d8afb01..b7d55cb 100644
--- a/src/server.go
+++ b/src/server.go
@@ -94,7 +94,8 @@ func (profile *ServerProfile) supportsWireguard() bool {
return false
}
-func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) {
+func (server *Server) getCurrentProfile() (*ServerProfile, error) {
+ profile_id := server.Profiles.Current
for _, profile := range server.Profiles.Info.ProfileList {
if profile.ID == profile_id {
return &profile, nil
@@ -103,28 +104,28 @@ func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error)
return nil, errors.New("no profile found for id")
}
-func (server *Server) getConfigWithProfile(profile_id string) (string, error) {
+func (server *Server) getConfigWithProfile() (string, error) {
if !GetVPNState().HasTransition(HAS_CONFIG) {
return "", errors.New("cannot get a config with a profile, invalid state")
}
- profile, profileErr := server.getProfileForID(profile_id)
+ profile, profileErr := server.getCurrentProfile()
if profileErr != nil {
return "", profileErr
}
if profile.supportsWireguard() {
- return server.WireguardGetConfig(profile_id)
+ return server.WireguardGetConfig()
}
- return server.OpenVPNGetConfig(profile_id)
+ return server.OpenVPNGetConfig()
}
-func (server *Server) askForProfileID() (string, error) {
+func (server *Server) askForProfileID() error {
if !GetVPNState().HasTransition(ASK_PROFILE) {
- return "", errors.New("cannot ask for a profile id, invalid state")
+ return errors.New("cannot ask for a profile id, invalid state")
}
- _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw)
- return profile_id, nil
+ GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw)
+ return nil
}
func (server *Server) GetConfig() (string, error) {
@@ -139,14 +140,15 @@ func (server *Server) GetConfig() (string, error) {
// Set the current profile if there is only one profile
if len(server.Profiles.Info.ProfileList) == 1 {
- return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID)
+ server.Profiles.Current = server.Profiles.Info.ProfileList[0].ID
+ return server.getConfigWithProfile()
}
- profile_id, profileErr := server.askForProfileID()
+ profileErr := server.askForProfileID()
if profileErr != nil {
return "", nil
}
- return server.getConfigWithProfile(profile_id)
+ return server.getConfigWithProfile()
}
diff --git a/src/server_test.go b/src/server_test.go
index 3045983..6199884 100644
--- a/src/server_test.go
+++ b/src/server_test.go
@@ -35,13 +35,10 @@ func LoginOAuthSelenium(t *testing.T, url string) {
}
}
-func StateCallback(t *testing.T, oldState string, newState string, data string) string {
+func StateCallback(t *testing.T, oldState string, newState string, data string) {
if newState == "OAuth_Started" {
go LoginOAuthSelenium(t, data)
}
-
- // We have no data to send back
- return ""
}
func Test_server(t *testing.T) {
@@ -50,8 +47,8 @@ func Test_server(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string {
- return StateCallback(t, old, new, data)
+ state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) {
+ StateCallback(t, old, new, data)
}, false)
_, configErr := state.Connect("https://eduvpnserver")
@@ -68,7 +65,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string {
+ state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) {
if newState == "OAuth_Started" {
baseURL := "http://127.0.0.1:8000/callback"
url, err := HTTPConstructURL(baseURL, parameters)
@@ -78,8 +75,6 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect
go http.Get(url)
}
- // We have no data to send back
- return ""
}, false)
_, configErr := state.Connect("https://eduvpnserver")
@@ -129,8 +124,8 @@ func Test_token_expired(t *testing.T) {
// Do not verify because during testing, the cert is self-signed
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
- state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string {
- return StateCallback(t, old, new, data)
+ state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) {
+ StateCallback(t, old, new, data)
}, false)
_, configErr := state.Connect("https://eduvpnserver")
diff --git a/src/state.go b/src/state.go
index b9bb920..6c740c1 100644
--- a/src/state.go
+++ b/src/state.go
@@ -6,10 +6,10 @@ import (
type VPNState struct {
// Info passed by the client
- ConfigDirectory string `json:"-"`
- Name string `json:"-"`
- StateCallback func(string, string, string) string `json:"-"`
- StateCallbackData string `json:"-"`
+ ConfigDirectory string `json:"-"`
+ Name string `json:"-"`
+ StateCallback func(string, string, string) `json:"-"`
+ StateCallbackData string `json:"-"`
// The chosen server
Server *Server `json:"server"`
@@ -27,7 +27,7 @@ type VPNState struct {
Debug bool `json:"-"`
}
-func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error {
+func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error {
state.InitializeFSM()
if !state.InState(DEREGISTERED) {
return errors.New("app already registered")
diff --git a/src/wireguard.go b/src/wireguard.go
index 3f7da40..2f1c41c 100644
--- a/src/wireguard.go
+++ b/src/wireguard.go
@@ -26,7 +26,8 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string {
return interface_re.ReplaceAllString(config, to_replace)
}
-func (server *Server) WireguardGetConfig(profile_id string) (string, error) {
+func (server *Server) WireguardGetConfig() (string, error) {
+ profile_id := server.Profiles.Current
wireguardKey, wireguardErr := wireguardGenerateKey()
if wireguardErr != nil {
diff --git a/wrappers/python/eduvpncommon/__init__.py b/wrappers/python/eduvpncommon/__init__.py
index 50af8eb..d7fce2f 100644
--- a/wrappers/python/eduvpncommon/__init__.py
+++ b/wrappers/python/eduvpncommon/__init__.py
@@ -38,7 +38,7 @@ lib.Deregister.argtypes, lib.Deregister.restype = [], None
lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p
lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [], DataError
lib.GetServersList.argtypes, lib.GetServersList.restype = [], DataError
-lib.SendData.argtypes, lib.SendData.restype = [c_char_p], None
+lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p], None
# We have to use c_void_p instead of c_char_p to free it properly
# See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error
lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None
diff --git a/wrappers/python/eduvpncommon/main.py b/wrappers/python/eduvpncommon/main.py
index 8d2de9c..694fc59 100644
--- a/wrappers/python/eduvpncommon/main.py
+++ b/wrappers/python/eduvpncommon/main.py
@@ -47,8 +47,8 @@ def register_callback(eduvpn):
)
-def SendData(data):
- lib.SendData(data.encode("utf-8"))
+def SetProfileID(profile_id):
+ lib.SetProfileID(profile_id.encode("utf-8"))
class EduVPN(object):
@@ -79,8 +79,8 @@ class EduVPN(object):
def callback(self, old_state, new_state, data):
self.event.run(old_state, new_state, data)
- def send_data(self, data):
- return SendData(data)
+ def set_profile(self, profile_id):
+ return SetProfileID(profile_id)
class EventHandler(object):
diff --git a/wrappers/python/main.py b/wrappers/python/main.py
index bc29239..be9ab6c 100644
--- a/wrappers/python/main.py
+++ b/wrappers/python/main.py
@@ -14,7 +14,7 @@ def oauth_initialized(url):
@_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter)
def ask_profile(profiles):
print("ASK PROFILE CB", profiles)
- _eduvpn.send_data("prefer-openvpn")
+ _eduvpn.set_profile("prefer-openvpn")
success = _eduvpn.register(debug=True)