diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-04-20 14:28:08 +0200 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-04-20 14:28:08 +0200 |
| commit | 1c54936626a4a30d0c6f69576a06ba3661f39dc6 (patch) | |
| tree | 426c76b4c55cf9a9efbc7bd1aa957baec57b2892 | |
| parent | 77c9f266553cbadfd5fb150a26c2162b705f151e (diff) | |
Profiles: Implement SetProfileID instead of getting generic data
| -rw-r--r-- | exports/exports.go | 31 | ||||
| -rw-r--r-- | src/fsm.go | 8 | ||||
| -rw-r--r-- | src/openvpn.go | 3 | ||||
| -rw-r--r-- | src/server.go | 26 | ||||
| -rw-r--r-- | src/server_test.go | 17 | ||||
| -rw-r--r-- | src/state.go | 10 | ||||
| -rw-r--r-- | src/wireguard.go | 3 | ||||
| -rw-r--r-- | wrappers/python/eduvpncommon/__init__.py | 2 | ||||
| -rw-r--r-- | wrappers/python/eduvpncommon/main.py | 8 | ||||
| -rw-r--r-- | wrappers/python/main.py | 2 |
10 files changed, 58 insertions, 52 deletions
diff --git a/exports/exports.go b/exports/exports.go index 7d2b1ce..af71eba 100644 --- a/exports/exports.go +++ b/exports/exports.go @@ -18,9 +18,9 @@ import "github.com/jwijenbergh/eduvpn-common/src" var P_StateCallback C.PythonCB -func StateCallback(old_state string, new_state string, data string) string { +func StateCallback(old_state string, new_state string, data string) { if P_StateCallback == nil { - return "" + return } oldState_c := C.CString(old_state) newState_c := C.CString(new_state) @@ -29,12 +29,6 @@ func StateCallback(old_state string, new_state string, data string) string { C.free(unsafe.Pointer(oldState_c)) C.free(unsafe.Pointer(newState_c)) C.free(unsafe.Pointer(data_c)) - - // Get state data and reset - state := eduvpn.GetVPNState() - received_data := state.StateCallbackData - state.StateCallbackData = "" - return received_data } //export Register @@ -81,9 +75,24 @@ func GetServersList() (*C.char, *C.char) { return C.CString(servers), C.CString(ErrorToString(serversErr)) } -//export SendData -func SendData(data *C.char) { - eduvpn.GetVPNState().StateCallbackData = C.GoString(data) +//export SetProfileID +func SetProfileID(data *C.char) { + state := eduvpn.GetVPNState() + + // No server + if state.Server == nil { + return + } + + // No profiles for server + if state.Server.Profiles == nil { + return + } + + // Set current profile to id + profile_id := C.GoString(data) + + state.Server.Profiles.Current = profile_id } //export FreeString @@ -119,21 +119,19 @@ func (eduvpn *VPNState) writeGraph() { f.WriteString(graph) } -func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) (bool, string) { +func (eduvpn *VPNState) GoTransition(newState FSMStateID, data string) bool { ok := eduvpn.HasTransition(newState) - received := "" - if ok { oldState := eduvpn.FSM.Current eduvpn.FSM.Current = newState if eduvpn.Debug { eduvpn.writeGraph() } - received = eduvpn.StateCallback(oldState.String(), newState.String(), data) + eduvpn.StateCallback(oldState.String(), newState.String(), data) } - return ok, received + return ok } func (eduvpn *VPNState) generateDotGraph() string { diff --git a/src/openvpn.go b/src/openvpn.go index 0cf8e36..95e1328 100644 --- a/src/openvpn.go +++ b/src/openvpn.go @@ -1,6 +1,7 @@ package eduvpn -func (server *Server) OpenVPNGetConfig(profile_id string) (string, error) { +func (server *Server) OpenVPNGetConfig() (string, error) { + profile_id := server.Profiles.Current configOpenVPN, _, configErr := server.APIConnectOpenVPN(profile_id) if configErr != nil { diff --git a/src/server.go b/src/server.go index d8afb01..b7d55cb 100644 --- a/src/server.go +++ b/src/server.go @@ -94,7 +94,8 @@ func (profile *ServerProfile) supportsWireguard() bool { return false } -func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) { +func (server *Server) getCurrentProfile() (*ServerProfile, error) { + profile_id := server.Profiles.Current for _, profile := range server.Profiles.Info.ProfileList { if profile.ID == profile_id { return &profile, nil @@ -103,28 +104,28 @@ func (server *Server) getProfileForID(profile_id string) (*ServerProfile, error) return nil, errors.New("no profile found for id") } -func (server *Server) getConfigWithProfile(profile_id string) (string, error) { +func (server *Server) getConfigWithProfile() (string, error) { if !GetVPNState().HasTransition(HAS_CONFIG) { return "", errors.New("cannot get a config with a profile, invalid state") } - profile, profileErr := server.getProfileForID(profile_id) + profile, profileErr := server.getCurrentProfile() if profileErr != nil { return "", profileErr } if profile.supportsWireguard() { - return server.WireguardGetConfig(profile_id) + return server.WireguardGetConfig() } - return server.OpenVPNGetConfig(profile_id) + return server.OpenVPNGetConfig() } -func (server *Server) askForProfileID() (string, error) { +func (server *Server) askForProfileID() error { if !GetVPNState().HasTransition(ASK_PROFILE) { - return "", errors.New("cannot ask for a profile id, invalid state") + return errors.New("cannot ask for a profile id, invalid state") } - _, profile_id := GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw) - return profile_id, nil + GetVPNState().GoTransition(ASK_PROFILE, server.ProfilesRaw) + return nil } func (server *Server) GetConfig() (string, error) { @@ -139,14 +140,15 @@ func (server *Server) GetConfig() (string, error) { // Set the current profile if there is only one profile if len(server.Profiles.Info.ProfileList) == 1 { - return server.getConfigWithProfile(server.Profiles.Info.ProfileList[0].ID) + server.Profiles.Current = server.Profiles.Info.ProfileList[0].ID + return server.getConfigWithProfile() } - profile_id, profileErr := server.askForProfileID() + profileErr := server.askForProfileID() if profileErr != nil { return "", nil } - return server.getConfigWithProfile(profile_id) + return server.getConfigWithProfile() } diff --git a/src/server_test.go b/src/server_test.go index 3045983..6199884 100644 --- a/src/server_test.go +++ b/src/server_test.go @@ -35,13 +35,10 @@ func LoginOAuthSelenium(t *testing.T, url string) { } } -func StateCallback(t *testing.T, oldState string, newState string, data string) string { +func StateCallback(t *testing.T, oldState string, newState string, data string) { if newState == "OAuth_Started" { go LoginOAuthSelenium(t, data) } - - // We have no data to send back - return "" } func Test_server(t *testing.T) { @@ -50,8 +47,8 @@ func Test_server(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) string { - return StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configstest", func(old string, new string, data string) { + StateCallback(t, old, new, data) }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -68,7 +65,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) string { + state.Register("org.eduvpn.app.linux", "configsnologin", func(oldState string, newState string, data string) { if newState == "OAuth_Started" { baseURL := "http://127.0.0.1:8000/callback" url, err := HTTPConstructURL(baseURL, parameters) @@ -78,8 +75,6 @@ func test_connect_oauth_parameter(t *testing.T, parameters URLParameters, expect go http.Get(url) } - // We have no data to send back - return "" }, false) _, configErr := state.Connect("https://eduvpnserver") @@ -129,8 +124,8 @@ func Test_token_expired(t *testing.T) { // Do not verify because during testing, the cert is self-signed http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) string { - return StateCallback(t, old, new, data) + state.Register("org.eduvpn.app.linux", "configsexpired", func(old string, new string, data string) { + StateCallback(t, old, new, data) }, false) _, configErr := state.Connect("https://eduvpnserver") diff --git a/src/state.go b/src/state.go index b9bb920..6c740c1 100644 --- a/src/state.go +++ b/src/state.go @@ -6,10 +6,10 @@ import ( type VPNState struct { // Info passed by the client - ConfigDirectory string `json:"-"` - Name string `json:"-"` - StateCallback func(string, string, string) string `json:"-"` - StateCallbackData string `json:"-"` + ConfigDirectory string `json:"-"` + Name string `json:"-"` + StateCallback func(string, string, string) `json:"-"` + StateCallbackData string `json:"-"` // The chosen server Server *Server `json:"server"` @@ -27,7 +27,7 @@ type VPNState struct { Debug bool `json:"-"` } -func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string) string, debug bool) error { +func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { state.InitializeFSM() if !state.InState(DEREGISTERED) { return errors.New("app already registered") diff --git a/src/wireguard.go b/src/wireguard.go index 3f7da40..2f1c41c 100644 --- a/src/wireguard.go +++ b/src/wireguard.go @@ -26,7 +26,8 @@ func wireguardConfigAddKey(config string, key wgtypes.Key) string { return interface_re.ReplaceAllString(config, to_replace) } -func (server *Server) WireguardGetConfig(profile_id string) (string, error) { +func (server *Server) WireguardGetConfig() (string, error) { + profile_id := server.Profiles.Current wireguardKey, wireguardErr := wireguardGenerateKey() if wireguardErr != nil { diff --git a/wrappers/python/eduvpncommon/__init__.py b/wrappers/python/eduvpncommon/__init__.py index 50af8eb..d7fce2f 100644 --- a/wrappers/python/eduvpncommon/__init__.py +++ b/wrappers/python/eduvpncommon/__init__.py @@ -38,7 +38,7 @@ lib.Deregister.argtypes, lib.Deregister.restype = [], None lib.Register.argtypes, lib.Register.restype = [c_char_p, c_char_p, VPNStateChange, c_int], c_void_p lib.GetOrganizationsList.argtypes, lib.GetOrganizationsList.restype = [], DataError lib.GetServersList.argtypes, lib.GetServersList.restype = [], DataError -lib.SendData.argtypes, lib.SendData.restype = [c_char_p], None +lib.SetProfileID.argtypes, lib.SetProfileID.restype = [c_char_p], None # We have to use c_void_p instead of c_char_p to free it properly # See https://stackoverflow.com/questions/13445568/python-ctypes-how-to-free-memory-getting-invalid-pointer-error lib.FreeString.argtypes, lib.FreeString.restype = [c_void_p], None diff --git a/wrappers/python/eduvpncommon/main.py b/wrappers/python/eduvpncommon/main.py index 8d2de9c..694fc59 100644 --- a/wrappers/python/eduvpncommon/main.py +++ b/wrappers/python/eduvpncommon/main.py @@ -47,8 +47,8 @@ def register_callback(eduvpn): ) -def SendData(data): - lib.SendData(data.encode("utf-8")) +def SetProfileID(profile_id): + lib.SetProfileID(profile_id.encode("utf-8")) class EduVPN(object): @@ -79,8 +79,8 @@ class EduVPN(object): def callback(self, old_state, new_state, data): self.event.run(old_state, new_state, data) - def send_data(self, data): - return SendData(data) + def set_profile(self, profile_id): + return SetProfileID(profile_id) class EventHandler(object): diff --git a/wrappers/python/main.py b/wrappers/python/main.py index bc29239..be9ab6c 100644 --- a/wrappers/python/main.py +++ b/wrappers/python/main.py @@ -14,7 +14,7 @@ def oauth_initialized(url): @_eduvpn.event.on("Ask_Profile", eduvpn.StateType.Enter) def ask_profile(profiles): print("ASK PROFILE CB", profiles) - _eduvpn.send_data("prefer-openvpn") + _eduvpn.set_profile("prefer-openvpn") success = _eduvpn.register(debug=True) |
