summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--client/client.go10
-rw-r--r--client/client_test.go18
-rw-r--r--client/fsm.go2
-rw-r--r--client/server.go13
-rw-r--r--cmd/cli/main.go3
-rw-r--r--exports/exports.go17
-rw-r--r--internal/fsm/fsm.go21
-rw-r--r--wrappers/python/eduvpn_common/event.py15
-rw-r--r--wrappers/python/eduvpn_common/main.py14
-rw-r--r--wrappers/python/eduvpn_common/types.py2
10 files changed, 77 insertions, 38 deletions
diff --git a/client/client.go b/client/client.go
index 103136b..63e4f91 100644
--- a/client/client.go
+++ b/client/client.go
@@ -73,7 +73,7 @@ func (client *Client) Register(
name string,
directory string,
language string,
- stateCallback func(FSMStateID, FSMStateID, interface{}),
+ stateCallback func(FSMStateID, FSMStateID, interface{}) bool,
debug bool,
) error {
errorMessage := "failed to register with the GO library"
@@ -155,11 +155,15 @@ func (client *Client) Deregister() {
// askProfile asks the user for a profile by moving the FSM to the ASK_PROFILE state.
func (client *Client) askProfile(chosenServer server.Server) error {
+ errorMessage := "failed asking for profiles"
profiles, profilesErr := server.GetValidProfiles(chosenServer, client.SupportsWireguard)
if profilesErr != nil {
- return types.NewWrappedError("failed asking for profiles", profilesErr)
+ return types.NewWrappedError(errorMessage, profilesErr)
+ }
+ goTransitionErr := client.FSM.GoTransitionRequired(STATE_ASK_PROFILE, profiles)
+ if goTransitionErr != nil {
+ return types.NewWrappedError(errorMessage, goTransitionErr)
}
- client.FSM.GoTransitionWithData(STATE_ASK_PROFILE, profiles)
return nil
}
diff --git a/client/client_test.go b/client/client_test.go
index f386c3c..adf5c75 100644
--- a/client/client_test.go
+++ b/client/client_test.go
@@ -81,8 +81,9 @@ func Test_server(t *testing.T) {
"org.letsconnect-vpn.app.linux",
"configstest",
"en",
- func(old FSMStateID, new FSMStateID, data interface{}) {
+ func(old FSMStateID, new FSMStateID, data interface{}) bool {
stateCallback(t, old, new, data, state)
+ return true
},
false,
)
@@ -113,7 +114,7 @@ func test_connect_oauth_parameter(
"org.letsconnect-vpn.app.linux",
configDirectory,
"en",
- func(oldState FSMStateID, newState FSMStateID, data interface{}) {
+ func(oldState FSMStateID, newState FSMStateID, data interface{}) bool {
if newState == STATE_OAUTH_STARTED {
server, serverErr := state.Servers.GetCustomServer(serverURI)
if serverErr != nil {
@@ -142,6 +143,7 @@ func test_connect_oauth_parameter(
}
}()
}
+ return true
},
false,
)
@@ -219,8 +221,9 @@ func Test_token_expired(t *testing.T) {
"org.letsconnect-vpn.app.linux",
"configsexpired",
"en",
- func(old FSMStateID, new FSMStateID, data interface{}) {
+ func(old FSMStateID, new FSMStateID, data interface{}) bool {
stateCallback(t, old, new, data, state)
+ return true
},
false,
)
@@ -279,8 +282,9 @@ func Test_token_invalid(t *testing.T) {
"org.letsconnect-vpn.app.linux",
"configsinvalid",
"en",
- func(old FSMStateID, new FSMStateID, data interface{}) {
+ func(old FSMStateID, new FSMStateID, data interface{}) bool {
stateCallback(t, old, new, data, state)
+ return true
},
false,
)
@@ -336,8 +340,9 @@ func Test_invalid_profile_corrected(t *testing.T) {
"org.letsconnect-vpn.app.linux",
"configscancelprofile",
"en",
- func(old FSMStateID, new FSMStateID, data interface{}) {
+ func(old FSMStateID, new FSMStateID, data interface{}) bool {
stateCallback(t, old, new, data, state)
+ return true
},
false,
)
@@ -393,8 +398,9 @@ func Test_prefer_tcp(t *testing.T) {
"org.letsconnect-vpn.app.linux",
"configsprefertcp",
"en",
- func(old FSMStateID, new FSMStateID, data interface{}) {
+ func(old FSMStateID, new FSMStateID, data interface{}) bool {
stateCallback(t, old, new, data, state)
+ return true
},
false,
)
diff --git a/client/fsm.go b/client/fsm.go
index 767ceaa..c93a9a8 100644
--- a/client/fsm.go
+++ b/client/fsm.go
@@ -96,7 +96,7 @@ func GetStateName(s FSMStateID) string {
}
func newFSM(
- callback func(FSMStateID, FSMStateID, interface{}),
+ callback func(FSMStateID, FSMStateID, interface{}) bool,
directory string,
debug bool,
) fsm.FSM {
diff --git a/client/server.go b/client/server.go
index 3bcecc6..d22dc65 100644
--- a/client/server.go
+++ b/client/server.go
@@ -479,16 +479,20 @@ func (client *Client) GetConfigCustomServer(url string, preferTCP bool) (string,
// askSecureLocation asks the user to choose a Secure Internet location by moving the FSM to the STATE_ASK_LOCATION state.
func (client *Client) askSecureLocation() error {
+ errorMessage := "failed settings secure location"
locations := client.Discovery.GetSecureLocationList()
// Ask for the location in the callback
- client.FSM.GoTransitionWithData(STATE_ASK_LOCATION, locations)
+ goTransitionErr := client.FSM.GoTransitionRequired(STATE_ASK_LOCATION, locations)
+ if goTransitionErr != nil {
+ return types.NewWrappedError(errorMessage, goTransitionErr)
+ }
// The state has changed, meaning setting the secure location was not successful
if client.FSM.Current != STATE_ASK_LOCATION {
// TODO: maybe a custom type for this errors.new?
return types.NewWrappedError(
- "failed setting secure location",
+ errorMessage,
errors.New("failed loading secure location"),
)
}
@@ -579,7 +583,10 @@ func (client *Client) ensureLogin(chosenServer server.Server) error {
if server.NeedsRelogin(chosenServer) {
url, urlErr := server.GetOAuthURL(chosenServer, client.Name)
- client.FSM.GoTransitionWithData(STATE_OAUTH_STARTED, url)
+ goTransitionErr := client.FSM.GoTransitionRequired(STATE_OAUTH_STARTED, url)
+ if goTransitionErr != nil {
+ return types.NewWrappedError(errorMessage, goTransitionErr)
+ }
if urlErr != nil {
client.goBackInternal()
diff --git a/cmd/cli/main.go b/cmd/cli/main.go
index 97d0f39..96ca42f 100644
--- a/cmd/cli/main.go
+++ b/cmd/cli/main.go
@@ -125,8 +125,9 @@ func printConfig(url string, serverType ServerTypes) {
"org.eduvpn.app.linux",
"configs",
"en",
- func(old client.FSMStateID, new client.FSMStateID, data interface{}) {
+ func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool {
stateCallback(state, old, new, data)
+ return true
},
true,
)
diff --git a/exports/exports.go b/exports/exports.go
index 05f2462..8e418d4 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -4,11 +4,11 @@ package main
#include <stdlib.h>
#include "error.h"
-typedef void (*PythonCB)(const char* name, int oldstate, int newstate, void* data);
+typedef int (*PythonCB)(const char* name, int oldstate, int newstate, void* data);
-static void call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data)
+static int call_callback(PythonCB callback, const char *name, int oldstate, int newstate, void* data)
{
- callback(name, oldstate, newstate, data);
+ return callback(name, oldstate, newstate, data);
}
*/
import "C"
@@ -61,18 +61,19 @@ func StateCallback(
old_state client.FSMStateID,
new_state client.FSMStateID,
data interface{},
-) {
+) bool {
P_StateCallback, exists := P_StateCallbacks[name]
if !exists || P_StateCallback == nil {
- return
+ return false
}
name_c := C.CString(name)
oldState_c := C.int(old_state)
newState_c := C.int(new_state)
data_c := GetStateData(state, new_state, data)
- C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c)
+ handled := C.call_callback(P_StateCallback, name_c, oldState_c, newState_c, data_c)
C.free(unsafe.Pointer(name_c))
// data_c gets freed by the wrapper
+ return handled == C.int(1)
}
func GetVPNState(name string) (*client.Client, error) {
@@ -110,8 +111,8 @@ func Register(
nameStr,
C.GoString(config_directory),
C.GoString(language),
- func(old client.FSMStateID, new client.FSMStateID, data interface{}) {
- StateCallback(state, nameStr, old, new, data)
+ func(old client.FSMStateID, new client.FSMStateID, data interface{}) bool {
+ return StateCallback(state, nameStr, old, new, data)
},
debug != 0,
)
diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go
index 198d51a..1b45ce8 100644
--- a/internal/fsm/fsm.go
+++ b/internal/fsm/fsm.go
@@ -6,6 +6,7 @@ import (
"os/exec"
"path"
"sort"
+ "github.com/eduvpn/eduvpn-common/types"
)
type (
@@ -47,7 +48,7 @@ type FSM struct {
// Info to be passed from the parent state
Name string
- StateCallback func(FSMStateID, FSMStateID, interface{})
+ StateCallback func(FSMStateID, FSMStateID, interface{}) bool
Directory string
Debug bool
GetName func(FSMStateID) string
@@ -56,7 +57,7 @@ type FSM struct {
func (fsm *FSM) Init(
current FSMStateID,
states map[FSMStateID]FSMState,
- callback func(FSMStateID, FSMStateID, interface{}),
+ callback func(FSMStateID, FSMStateID, interface{}) bool,
directory string,
nameGen func(FSMStateID) string,
debug bool,
@@ -110,9 +111,18 @@ func (fsm *FSM) GoBack() {
fsm.GoTransition(fsm.States[fsm.Current].BackState)
}
+func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error {
+ oldState := fsm.Current
+ if !fsm.GoTransitionWithData(newState, data) {
+ return types.NewWrappedError("failed required transition", fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetName(oldState), fsm.GetName(newState)))
+ }
+ return nil
+}
+
func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool {
ok := fsm.HasTransition(newState)
+ handled := false
if ok {
oldState := fsm.Current
fsm.Current = newState
@@ -120,14 +130,15 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool
fsm.writeGraph()
}
- fsm.StateCallback(oldState, newState, data)
+ handled = fsm.StateCallback(oldState, newState, data)
}
- return ok
+ return handled
}
func (fsm *FSM) GoTransition(newState FSMStateID) bool {
- return fsm.GoTransitionWithData(newState, "{}")
+ // No data means the callback is never required
+ return fsm.GoTransitionWithData(newState, "")
}
func (fsm *FSM) generateMermaidGraph() string {
diff --git a/wrappers/python/eduvpn_common/event.py b/wrappers/python/eduvpn_common/event.py
index 1823130..4387222 100644
--- a/wrappers/python/eduvpn_common/event.py
+++ b/wrappers/python/eduvpn_common/event.py
@@ -147,7 +147,7 @@ class EventHandler(object):
def run_state(
self, state: int, other_state: int, state_type: StateType, data: str
- ) -> None:
+ ) -> bool:
"""The function that runs the callback for a specific event
:param state: int: The state of the event
@@ -158,13 +158,14 @@ class EventHandler(object):
:meta private:
"""
if (state, state_type) not in self.handlers:
- return
+ return False
for func in self.handlers[(state, state_type)]:
func(other_state, data)
+ return True
def run(
self, old_state: int, new_state: int, data: Any, convert: bool = True
- ) -> None:
+ ) -> bool:
"""Run a specific event.
It converts the data and then runs the event for all state types
@@ -180,5 +181,9 @@ class EventHandler(object):
if convert:
converted = convert_data(self.lib, new_state, data)
self.run_state(old_state, new_state, StateType.LEAVE, converted)
- self.run_state(new_state, old_state, StateType.ENTER, converted)
- self.run_state(new_state, old_state, StateType.WAIT, converted)
+ # We decide handled based on enter transitions
+ handled = self.run_state(new_state, old_state, StateType.ENTER, converted)
+ # Only run wait transitions if the enter transition is handled
+ if handled:
+ self.run_state(new_state, old_state, StateType.WAIT, converted)
+ return handled
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index fff7e31..20d646f 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -1,4 +1,5 @@
import threading
+from ctypes import c_int
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from eduvpn_common.discovery import DiscoOrganizations, DiscoServers, get_disco_organizations, get_disco_servers
@@ -371,7 +372,7 @@ class EduVPN(object):
"""
return self.event_handler
- def callback(self, old_state: State, new_state: State, data: Any) -> None:
+ def callback(self, old_state: State, new_state: State, data: Any) -> bool:
"""Run an event callback
:param old_state: State: The previous state
@@ -379,7 +380,7 @@ class EduVPN(object):
:param data: Any: The data to pass to the event
"""
- self.event.run(old_state, new_state, data)
+ return self.event.run(old_state, new_state, data)
def set_profile(self, profile_id: str) -> None:
"""Set the profile of the current server
@@ -506,7 +507,7 @@ eduvpn_objects: Dict[str, EduVPN] = {}
@VPNStateChange
-def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> None:
+def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> int:
"""The internal callback that is passed to the Go library
:param name: bytes: The name of the client
@@ -518,8 +519,11 @@ def state_callback(name: bytes, old_state: int, new_state: int, data: Any) -> No
"""
name_decoded = name.decode()
if name_decoded not in eduvpn_objects:
- return
- eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data)
+ return 0
+ handled = eduvpn_objects[name_decoded].callback(State(old_state), State(new_state), data)
+ if handled:
+ return 1
+ return 0
def add_as_global_object(eduvpn: EduVPN) -> bool:
diff --git a/wrappers/python/eduvpn_common/types.py b/wrappers/python/eduvpn_common/types.py
index a6eda43..07a02d3 100644
--- a/wrappers/python/eduvpn_common/types.py
+++ b/wrappers/python/eduvpn_common/types.py
@@ -165,7 +165,7 @@ class ConfigError(Structure):
# The type for a Go state change callback
-VPNStateChange = CFUNCTYPE(None, c_char_p, c_int, c_int, c_void_p)
+VPNStateChange = CFUNCTYPE(c_int, c_char_p, c_int, c_int, c_void_p)
def encode_args(args: List[Any], types: List[Any]) -> Iterator[Any]: