summaryrefslogtreecommitdiff
path: root/state.go
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-08-19 16:32:35 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-08-19 16:32:35 +0200
commitd8c7f962e4fe2d4a46f0aeb1c9d9a371d5e41ee0 (patch)
treed98c682b31cccb975483111e5b817d5c8d029838 /state.go
parentf81d05226fe61b697baa91e926dd86efad9d8084 (diff)
State + FSM: Properly handle the disconnect flow
- /disconnect is now called - A new state is added (DISCONNECTING) that waits for the disconnect to complete - A helper function is exposed (InFSMState) that can be used by clients to see in which state they are in
Diffstat (limited to 'state.go')
-rw-r--r--state.go54
1 files changed, 40 insertions, 14 deletions
diff --git a/state.go b/state.go
index 3ad979d..139f67a 100644
--- a/state.go
+++ b/state.go
@@ -14,7 +14,7 @@ import (
)
type ServerInfo = server.ServerInfoScreen
-type VPNStateID = fsm.FSMStateID
+type StateID = fsm.FSMStateID
type VPNState struct {
// The chosen server
@@ -40,9 +40,9 @@ func (state *VPNState) GetSavedServers() *server.ServersConfiguredScreen {
return state.Servers.GetServersConfigured()
}
-func (state *VPNState) Register(name string, directory string, stateCallback func(VPNStateID, VPNStateID, interface{}), debug bool) error {
+func (state *VPNState) Register(name string, directory string, stateCallback func(StateID, StateID, interface{}), debug bool) error {
errorMessage := "failed to register with the GO library"
- if !state.FSM.InState(fsm.DEREGISTERED) {
+ if !state.InFSMState(fsm.DEREGISTERED) {
return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()}
}
// Initialize the logger
@@ -92,7 +92,7 @@ func (state *VPNState) Deregister() error {
func (state *VPNState) GoBack() error {
errorMessage := "failed to go back"
- if state.FSM.InState(fsm.DEREGISTERED) {
+ if state.InFSMState(fsm.DEREGISTERED) {
return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()}
}
@@ -104,7 +104,7 @@ func (state *VPNState) GoBack() error {
func (state *VPNState) getConfig(chosenServer server.Server, forceTCP bool) (string, string, error) {
errorMessage := "failed to get a configuration for OpenVPN/Wireguard"
- if state.FSM.InState(fsm.DEREGISTERED) {
+ if state.InFSMState(fsm.DEREGISTERED) {
return "", "", &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.DeregisteredError{}.CustomError()}
}
@@ -273,7 +273,7 @@ func (state *VPNState) GetConfigCustomServer(url string, forceTCP bool) (string,
func (state *VPNState) CancelOAuth() error {
errorMessage := "failed to cancel OAuth"
- if !state.FSM.InState(fsm.OAUTH_STARTED) {
+ if !state.InFSMState(fsm.OAUTH_STARTED) {
return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED}.CustomError()}
}
@@ -289,7 +289,7 @@ func (state *VPNState) CancelOAuth() error {
func (state *VPNState) ChangeSecureLocation() error {
errorMessage := "failed to change location from the main screen"
- if !state.FSM.InState(fsm.NO_SERVER) {
+ if !state.InFSMState(fsm.NO_SERVER) {
return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateError{Got: state.FSM.Current, Want: fsm.NO_SERVER}.CustomError()}
}
@@ -306,14 +306,14 @@ func (state *VPNState) ChangeSecureLocation() error {
}
func (state *VPNState) GetDiscoOrganizations() (string, error) {
- if state.FSM.InState(fsm.DEREGISTERED) {
+ if state.InFSMState(fsm.DEREGISTERED) {
return "", &types.WrappedErrorMessage{Message: "failed to get the organizations with Discovery", Err: fsm.DeregisteredError{}.CustomError()}
}
return state.Discovery.GetOrganizationsList()
}
func (state *VPNState) GetDiscoServers() (string, error) {
- if state.FSM.InState(fsm.DEREGISTERED) {
+ if state.InFSMState(fsm.DEREGISTERED) {
return "", &types.WrappedErrorMessage{Message: "failed to get the servers with Discovery", Err: fsm.DeregisteredError{}.CustomError()}
}
return state.Discovery.GetServersList()
@@ -351,7 +351,7 @@ func (state *VPNState) getServerInfoData() *server.ServerInfoScreen {
}
func (state *VPNState) SetConnected() error {
- if state.FSM.InState(fsm.CONNECTED) {
+ if state.InFSMState(fsm.CONNECTED) {
// already connected, show no error
return nil
}
@@ -364,7 +364,7 @@ func (state *VPNState) SetConnected() error {
}
func (state *VPNState) SetConnecting() error {
- if state.FSM.InState(fsm.CONNECTING) {
+ if state.InFSMState(fsm.CONNECTING) {
// already loading connection, show no error
return nil
}
@@ -376,15 +376,37 @@ func (state *VPNState) SetConnecting() error {
return nil
}
+func (state *VPNState) SetDisconnecting() error {
+ if state.InFSMState(fsm.DISCONNECTING) {
+ // already disconnecting, show no error
+ return nil
+ }
+ if !state.FSM.HasTransition(fsm.DISCONNECTING) {
+ return &types.WrappedErrorMessage{Message: "failed to set disconnecting", Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.DISCONNECTING}.CustomError()}
+ }
+
+
+ state.FSM.GoTransitionWithData(fsm.DISCONNECTING, state.getServerInfoData(), false)
+ return nil
+}
+
func (state *VPNState) SetDisconnected() error {
- if state.FSM.InState(fsm.HAS_CONFIG) {
+ errorMessage := "failed to set disconnected"
+ if state.InFSMState(fsm.HAS_CONFIG) {
// already disconnected, show no error
return nil
}
if !state.FSM.HasTransition(fsm.HAS_CONFIG) {
- return &types.WrappedErrorMessage{Message: "failed to set disconnected", Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError()}
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: fsm.WrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG}.CustomError()}
}
+ // Do the /disconnect API call and go to disconnected after...
+ currentServer, currentServerErr := state.Servers.GetCurrentServer()
+ if currentServerErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: currentServerErr}
+ }
+ server.Disconnect(currentServer)
+
state.FSM.GoTransitionWithData(fsm.HAS_CONFIG, state.getServerInfoData(), false)
return nil
@@ -417,7 +439,7 @@ func (state *VPNState) RenewSession() error {
}
func (state *VPNState) ShouldRenewButton() bool {
- if !state.FSM.InState(fsm.CONNECTED) {
+ if !state.InFSMState(fsm.CONNECTED) {
return false
}
@@ -431,6 +453,10 @@ func (state *VPNState) ShouldRenewButton() bool {
return server.ShouldRenewButton(currentServer)
}
+func (state *VPNState) InFSMState(checkState StateID) bool {
+ return state.FSM.InState(checkState)
+}
+
func GetErrorCause(err error) error {
return types.GetErrorCause(err)
}