summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-05-02 13:06:50 +0200
committerJeroen Wijenbergh <46386452+jwijenbergh@users.noreply.github.com>2023-09-25 09:43:37 +0200
commitd0f4303ee5ccecc876fdfae1ce858369e3a323b7 (patch)
tree64bda3971398aa56f0672bab49a63a077bf6ae7c
parentdc7425beb85ea7d35e860a5df2bc0d8ddda8c28a (diff)
Client + FSM: Check transitions and add SetState
Also make sure GotConfig can be used to go back to
-rw-r--r--client/client.go93
-rw-r--r--client/fsm.go17
-rw-r--r--exports/exports.go11
-rw-r--r--internal/fsm/fsm.go19
-rw-r--r--wrappers/python/eduvpn_common/loader.py3
-rw-r--r--wrappers/python/eduvpn_common/main.py5
6 files changed, 105 insertions, 43 deletions
diff --git a/client/client.go b/client/client.go
index adc66d0..f0add42 100644
--- a/client/client.go
+++ b/client/client.go
@@ -162,6 +162,17 @@ func (c *Client) forwardTokens(srv server.Server) error {
return nil
}
+func (c *Client) goTransition(id fsm.StateID) error {
+ handled, err := c.FSM.GoTransition(id)
+ if err != nil {
+ return err
+ }
+ if !handled {
+ log.Logger.Debugf("transition not handled by the client: %s", GetStateName(id))
+ }
+ return nil
+}
+
// New creates a new client with the following parameters:
// - name: the name of the client
// - directory: the directory where the config files are stored. Absolute or relative
@@ -223,7 +234,10 @@ func (c *Client) Register() error {
if !c.FSM.InState(StateDeregistered) {
return errors.Errorf("fsm attempt to register while in '%v'", c.FSM.Current)
}
- c.FSM.GoTransition(StateNoServer)
+ err := c.goTransition(StateNoServer)
+ if err != nil {
+ return errors.WrapPrefix(err, "failed to register", 0)
+ }
return nil
}
@@ -340,9 +354,9 @@ func (c *Client) locationCallback(ck *cookie.Cookie) error {
if err != nil {
return err
}
- t := c.FSM.GoTransition(StateChosenLocation)
- if !t {
- log.Logger.Warningf("transition chosen location not completed")
+ err = c.goTransition(StateChosenLocation)
+ if err != nil {
+ return err
}
return nil
}
@@ -372,16 +386,16 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool)
}
}
- t := c.FSM.GoTransition(StateChosenServer)
- if !t {
- log.Logger.Warningf("transition not completed for chosen server")
+ err := c.goTransition(StateChosenServer)
+ if err != nil {
+ return err
}
// oauth
// TODO: This should be ck.Context()
// But needsrelogin needs a rewrite to support this properly
// first make sure we get the most up to date tokens from the client
- err := c.updateTokens(srv)
+ err = c.updateTokens(srv)
if err != nil {
log.Logger.Debugf("failed to get tokens from client: %v", err)
}
@@ -396,9 +410,9 @@ func (c *Client) callbacks(ck *cookie.Cookie, srv server.Server, forceauth bool)
return err
}
}
- t = c.FSM.GoTransition(StateAuthorized)
- if !t {
- log.Logger.Warningf("transition authorized not completed")
+ err = c.goTransition(StateAuthorized)
+ if err != nil {
+ return err
}
return nil
@@ -434,37 +448,34 @@ func (c *Client) profileCallback(ck *cookie.Cookie, srv server.Server) error {
return err
}
}
- t := c.FSM.GoTransition(StateChosenProfile)
- if !t {
- log.Logger.Warningf("transition chosen profile not completed")
+ err = c.goTransition(StateChosenProfile)
+ if err != nil {
+ return err
}
return nil
}
// AddServer adds a server with identifier and type
func (c *Client) AddServer(ck *cookie.Cookie, identifier string, _type srvtypes.Type, ni bool) (err error) {
+
// If we have failed to add the server, we remove it again
// We add the server because we can then obtain it in other callback functions
+ previousState := c.FSM.Current
defer func() {
if err != nil {
_ = c.RemoveServer(identifier, _type) //nolint:errcheck
}
- if !ni {
- c.FSM.GoTransition(StateNoServer)
+ // If we must run callbacks, go to the previous state if we're not in it
+ if !ni && !c.FSM.InState(previousState) {
+ c.FSM.GoTransition(previousState)
}
}()
if !ni {
- // Try to go to no server
- c.FSM.GoTransition(StateNoServer)
-
- // If the transition was not successful, log
- if !c.FSM.InState(StateNoServer) {
- return errors.Errorf("wrong state to add a server: %s", GetStateName(c.FSM.Current))
- }
- t := c.FSM.GoTransition(StateLoadingServer)
- if !t {
- log.Logger.Warningf("transition not completed for loading server")
+ // This only returns a boolean if it was handled
+ err = c.goTransition(StateLoadingServer)
+ if err != nil {
+ return err
}
}
@@ -533,9 +544,9 @@ func (c *Client) config(ck *cookie.Cookie, srv server.Server, pTCP bool, forceAu
return nil, err
}
- t := c.FSM.GoTransition(StateRequestConfig)
- if !t {
- log.Logger.Warningf("transition not completed for requesting config")
+ err = c.goTransition(StateRequestConfig)
+ if err != nil {
+ return nil, err
}
err = c.profileCallback(ck, srv)
@@ -577,12 +588,15 @@ func (c *Client) server(identifier string, _type srvtypes.Type) (srv server.Serv
// GetConfig gets a VPN configuration
func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.Type, pTCP bool) (cfg *srvtypes.Configuration, err error) {
+ previousState := c.FSM.Current
defer func() {
if err == nil {
c.FSM.GoTransition(StateGotConfig)
} else {
- // go back if an error occurred
- c.FSM.GoTransition(StateNoServer)
+ if !c.FSM.InState(previousState) {
+ // go back to the previous state if an error occurred
+ c.FSM.GoTransition(previousState)
+ }
}
}()
if _type != srvtypes.TypeSecureInternet {
@@ -591,9 +605,9 @@ func (c *Client) GetConfig(ck *cookie.Cookie, identifier string, _type srvtypes.
return nil, err
}
}
- t := c.FSM.GoTransition(StateLoadingServer)
- if !t {
- log.Logger.Warningf("transition not completed for loading server")
+ err = c.goTransition(StateLoadingServer)
+ if err != nil {
+ return nil, err
}
srv, set, err := c.server(identifier, _type)
if err != nil {
@@ -862,3 +876,14 @@ func (c *Client) StartFailover(ck *cookie.Cookie, gateway string, mtu int, readR
return f.Start(ck.Context(), gateway, mtu)
}
+
+
+func (c *Client) SetState(state FSMStateID) error {
+ err := c.FSM.CheckTransition(state)
+ if err != nil {
+ return err
+ }
+ // TODO: Now we don't pass any data :/
+ c.FSM.GoTransition(state)
+ return nil
+}
diff --git a/client/fsm.go b/client/fsm.go
index 038e6cf..9c83b9b 100644
--- a/client/fsm.go
+++ b/client/fsm.go
@@ -98,12 +98,14 @@ func newFSM(
Transitions: []FSMTransition{
{To: StateChosenLocation, Description: "Location chosen"},
{To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateChosenLocation: FSMState{
Transitions: []FSMTransition{
{To: StateChosenServer, Description: "Server has been chosen"},
{To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateLoadingServer: FSMState{
@@ -114,45 +116,54 @@ func newFSM(
Description: "User chooses a Secure Internet server but no location is configured",
},
{To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateChosenServer: FSMState{
Transitions: []FSMTransition{
{To: StateAuthorized, Description: "Found tokens in config"},
{To: StateOAuthStarted, Description: "No tokens found in config"},
+ {To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateOAuthStarted: FSMState{
Transitions: []FSMTransition{
{To: StateAuthorized, Description: "User authorizes with browser"},
{To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateAuthorized: FSMState{
Transitions: []FSMTransition{
{To: StateOAuthStarted, Description: "Re-authorize with OAuth"},
{To: StateRequestConfig, Description: "Client requests a config"},
- {To: StateNoServer, Description: "Client wants to go back to the main screen"},
+ {To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateRequestConfig: FSMState{
Transitions: []FSMTransition{
{To: StateAskProfile, Description: "Multiple profiles found and no profile chosen"},
{To: StateChosenProfile, Description: "Only one profile or profile already chosen"},
- {To: StateNoServer, Description: "Cancel or Error"},
{To: StateOAuthStarted, Description: "Re-authorize"},
+ {To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateAskProfile: FSMState{
Transitions: []FSMTransition{
{To: StateNoServer, Description: "Cancel or Error"},
{To: StateChosenProfile, Description: "Profile has been chosen"},
+ {To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateChosenProfile: FSMState{
Transitions: []FSMTransition{
- {To: StateNoServer, Description: "Cancel or Error"},
{To: StateGotConfig, Description: "Config has been obtained"},
+ {To: StateNoServer, Description: "Go back or Error"},
+ {To: StateGotConfig, Description: "Go back or Error"},
},
},
StateGotConfig: FSMState{
diff --git a/exports/exports.go b/exports/exports.go
index ae741d7..5e62e4d 100644
--- a/exports/exports.go
+++ b/exports/exports.go
@@ -512,6 +512,17 @@ func StartFailover(c C.uintptr_t, gateway *C.char, mtu C.int, readRxBytes C.Read
return droppedC, nil
}
+// SetState sets the state of the statemachine
+// Note that this transitions the FSM into the new state without passing any data to it
+//export SetState
+func SetState(fsmState C.int) *C.char {
+ state, stateErr := getVPNState()
+ if stateErr != nil {
+ return getCError(stateErr)
+ }
+ return getCError(state.SetState(client.FSMStateID(fsmState)))
+}
+
// FreeString frees a string that was allocated by the eduvpn-common Go library
// This happens when we return strings, such as errors from the Go lib back to the client
// The client MUST thus ensure that this memory is freed using this function
diff --git a/internal/fsm/fsm.go b/internal/fsm/fsm.go
index 3cc8edb..cae8833 100644
--- a/internal/fsm/fsm.go
+++ b/internal/fsm/fsm.go
@@ -131,7 +131,14 @@ func (fsm *FSM) writeGraph() {
// If this transition is not handled by the client, it returns an error.
func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error {
oldState := fsm.Current
- if !fsm.GoTransitionWithData(newState, data) {
+
+ handled, err := fsm.GoTransitionWithData(newState, data)
+ // transition ios not possible
+ if err != nil {
+ return err
+ }
+ // transition is not handled
+ if !handled {
return errors.Errorf("fsm failed transition from '%v' to '%v', is this required transition handled?", fsm.GetStateName(oldState), fsm.GetStateName(newState))
}
return nil
@@ -139,9 +146,9 @@ func (fsm *FSM) GoTransitionRequired(newState StateID, data interface{}) error {
// GoTransitionWithData is a helper that transitions the state machine toward the 'newState' with associated state data 'data'
// It returns whether or not the transition is handled by the client.
-func (fsm *FSM) GoTransitionWithData(newState StateID, data interface{}) bool {
- if fsm.CheckTransition(newState) != nil {
- return false
+func (fsm *FSM) GoTransitionWithData(newState StateID, data interface{}) (bool, error) {
+ if err := fsm.CheckTransition(newState); err != nil {
+ return false, err
}
prev := fsm.Current
@@ -150,11 +157,11 @@ func (fsm *FSM) GoTransitionWithData(newState StateID, data interface{}) bool {
fsm.writeGraph()
}
- return fsm.StateCallback(prev, newState, data)
+ return fsm.StateCallback(prev, newState, data), nil
}
// GoTransition is an alias to call GoTransitionWithData but have an empty string as data.
-func (fsm *FSM) GoTransition(newState StateID) bool {
+func (fsm *FSM) GoTransition(newState StateID) (bool, error) {
// No data means the callback is never required
return fsm.GoTransitionWithData(newState, "")
}
diff --git a/wrappers/python/eduvpn_common/loader.py b/wrappers/python/eduvpn_common/loader.py
index 1a172af..afad569 100644
--- a/wrappers/python/eduvpn_common/loader.py
+++ b/wrappers/python/eduvpn_common/loader.py
@@ -102,6 +102,9 @@ def initialize_functions(lib: CDLL) -> None:
lib.SetSupportWireguard.argtypes, lib.SetSupportWireguard.restype = [
c_int,
], c_void_p
+ lib.SetState.argtypes, lib.SetState.restype = [
+ c_int,
+ ], c_void_p
lib.StartFailover.argtypes, lib.StartFailover.restype = [
c_int,
c_char_p,
diff --git a/wrappers/python/eduvpn_common/main.py b/wrappers/python/eduvpn_common/main.py
index 5608562..64fa6ac 100644
--- a/wrappers/python/eduvpn_common/main.py
+++ b/wrappers/python/eduvpn_common/main.py
@@ -175,6 +175,11 @@ class EduVPN(object):
if remove_err:
forwardError(remove_err)
+ def set_state(self, state: int):
+ state_err = self.go_function(self.lib.SetState, state)
+ if state_err:
+ forwardError(state_err)
+
def get_config(
self, _type: ServerType, identifier: str, prefer_tcp: bool = False
) -> str: