summaryrefslogtreecommitdiff
path: root/internal/fsm/fsm.go
blob: e6f3f3a76c63230b7b16974e525c239c87690de0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
// Package fsm defines a finite state machine and can save that state machine to a graph file.
// The graph file can be visualized using mermaid.js.
package fsm

import (
	"fmt"
	"os"
	"os/exec"
	"path"
	"sort"
	"github.com/eduvpn/eduvpn-common/types"
)

type (
	// FSMStateID is the identifier of a state.
	FSMStateID      int8
	// FSMStateIDSlice is a list of state identifiers. It implements
	// sort.Interface, ordering identifiers in ascending numeric order.
	FSMStateIDSlice []FSMStateID
)

// Compile-time check that FSMStateIDSlice can be passed to sort.Sort.
var _ sort.Interface = FSMStateIDSlice(nil)

// Len returns the number of state identifiers in the slice.
func (v FSMStateIDSlice) Len() int {
	return len(v)
}

// Less reports whether the identifier at index i is smaller than the one at index j.
func (v FSMStateIDSlice) Less(i, j int) bool {
	return v[i] < v[j]
}

// Swap exchanges the identifiers at indices i and j.
func (v FSMStateIDSlice) Swap(i, j int) {
	v[i], v[j] = v[j], v[i]
}

// FSMTransition is an outgoing arrow in the state graph.
type FSMTransition struct {
	// To is the state this transition leads to.
	To          FSMStateID
	// Description is the label drawn on the arrow in the graph.
	Description string
}

// FSMStates maps each state identifier to the definition of that state.
type FSMStates map[FSMStateID]FSMState

// FSMState is a single node in the state graph.
type FSMState struct {
	// Transitions lists the outgoing arrows of this node.
	Transitions []FSMTransition
}

// FSM represents the total graph: its states, the current state, and the hooks
// used when transitioning and when rendering the graph.
type FSM struct {
	// States is the map from state ID to states.
	States FSMStates

	// Current is the identifier of the current state.
	Current FSMStateID

	// Name represents the descriptive name of this state machine.
	// Init does not set this field.
	Name string

	// StateCallback is the function run when a transition occurs.
	// It takes the old state, the new state and the data, and returns whether
	// the transition was handled by the client.
	StateCallback func(FSMStateID, FSMStateID, interface{}) bool

	// Directory is the directory in which the state graph files are stored.
	Directory string

	// Generate represents whether the graph is regenerated on every successful transition.
	Generate bool

	// GetStateName returns the name of a state as a string.
	GetStateName func(FSMStateID) string
}

// Init configures the state machine and places it in the given current state.
//
// It assigns the state table, the transition callback, the graph output
// directory, the state-name function and the graph-generation flag.
// Name is left untouched.
func (fsm *FSM) Init(
	current FSMStateID,
	states map[FSMStateID]FSMState,
	callback func(FSMStateID, FSMStateID, interface{}) bool,
	directory string,
	nameGen func(FSMStateID) string,
	generate bool,
) {
	// Graph-related settings.
	fsm.Generate = generate
	fsm.Directory = directory
	fsm.GetStateName = nameGen

	// Core machine settings.
	fsm.StateCallback = callback
	fsm.States = states
	fsm.Current = current
}

// InState reports whether the current state of the machine equals 'check'.
func (fsm *FSM) InState(check FSMStateID) bool {
	current := fsm.Current
	return current == check
}

// HasTransition reports whether the current state has an outgoing transition
// whose target is 'check'. An unknown current state has no transitions.
func (fsm *FSM) HasTransition(check FSMStateID) bool {
	outgoing := fsm.States[fsm.Current].Transitions
	for i := range outgoing {
		if outgoing[i].To == check {
			return true
		}
	}
	return false
}

// graphFilename returns the path of the graph file inside fsm.Directory: the
// base name "graph" followed by the given extension (e.g. ".graph" or ".png").
func (fsm *FSM) graphFilename(extension string) string {
	return path.Join(fsm.Directory, "graph") + extension
}

// writeGraph writes the mermaid representation of the state machine to the
// ".graph" file in fsm.Directory. It then tries to render that file to a ".png"
// image with the mermaid CLI (mmdc).
//
// All failures are ignored. Graph generation is a best-effort debugging aid and
// must not interfere with state transitions.
func (fsm *FSM) writeGraph() {
	graphFile := fsm.graphFilename(".graph")
	graphImgFile := fsm.graphFilename(".png")

	// os.WriteFile also surfaces errors from closing the file, which a bare
	// f.Close() would drop (a failed close can mean the data never hit disk).
	if err := os.WriteFile(graphFile, []byte(fsm.GenerateGraph()), 0666); err != nil {
		// Nothing to render if the .graph file could not be written.
		return
	}

	cmd := exec.Command("mmdc", "-i", graphFile, "-o", graphImgFile, "--scale", "4")
	// Generating is best effort, e.g. mmdc may simply not be installed.
	if err := cmd.Start(); err != nil {
		return
	}
	// Reap the child in the background so it does not linger as a zombie
	// process. Its exit status is irrelevant here.
	go func() { _ = cmd.Wait() }()
}

// GoTransitionRequired transitions the state machine to newState with the
// associated state data 'data'.
// It returns an error when the current state has no transition to newState,
// or when the state callback reports that the transition was not handled.
func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error {
	oldState := fsm.Current
	if fsm.GoTransitionWithData(newState, data) {
		return nil
	}

	// GetStateName may be nil (e.g. on a zero-value FSM, which always ends up
	// here because it has no transitions). Fall back to the numeric ID rather
	// than panicking while building the error.
	name := func(id FSMStateID) string {
		if fsm.GetStateName == nil {
			return fmt.Sprintf("%d", id)
		}
		return fsm.GetStateName(id)
	}
	return types.NewWrappedError(
		"failed required transition",
		fmt.Errorf("required transition not handled, from: %s -> to: %s", name(oldState), name(newState)),
	)
}

// GoTransitionWithData moves the state machine to 'newState' with the
// associated state data 'data', but only if the current state has a transition to it.
//
// If no such transition exists, the state is left unchanged and false is
// returned. Otherwise the current state is updated. The graph is rewritten when
// Generate is set, and the result of StateCallback(old, new, data) is returned,
// i.e. whether the client handled the transition.
func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool {
	if !fsm.HasTransition(newState) {
		return false
	}

	previous := fsm.Current
	fsm.Current = newState
	if fsm.Generate {
		fsm.writeGraph()
	}
	return fsm.StateCallback(previous, newState, data)
}

// GoTransition transitions to newState without meaningful state data. It is
// shorthand for GoTransitionWithData with an empty string as the data, and
// returns whether the transition was handled by the client.
func (fsm *FSM) GoTransition(newState FSMStateID) bool {
	// No data means the callback is never required.
	const noData = ""
	return fsm.GoTransitionWithData(newState, noData)
}

// generateMermaidGraph renders the state machine as a mermaid.js flowchart
// ("graph TD") and returns it as a string.
//
// States are visited in ascending FSMStateID order, so the output is
// deterministic despite Go's randomized map iteration. Each state with at least
// one outgoing transition gets a single style line (cyan for the current state,
// white otherwise), followed by one edge line per transition.
// NOTE(review): states without outgoing transitions get no style line, so a
// current terminal state is not highlighted. This keeps the previous set of
// styled nodes.
func (fsm *FSM) generateMermaidGraph() string {
	ids := make(FSMStateIDSlice, 0, len(fsm.States))
	for id := range fsm.States {
		ids = append(ids, id)
	}
	sort.Sort(ids)

	// Accumulate into a byte slice to avoid quadratic string concatenation.
	buf := []byte("graph TD\n")
	for _, id := range ids {
		transitions := fsm.States[id].Transitions
		if len(transitions) == 0 {
			continue
		}

		name := fsm.GetStateName(id)
		fill := "white"
		if id == fsm.Current {
			fill = "cyan"
		}
		// Emit the style once per node instead of once per outgoing edge.
		buf = append(buf, "\nstyle "+name+" fill:"+fill+"\n"...)
		for _, transition := range transitions {
			buf = append(buf, name+"("+name+") -->|"+transition.Description+"| "+fsm.GetStateName(transition.To)+"\n"...)
		}
	}
	return string(buf)
}

// GenerateGraph returns the mermaid graph of the state machine.
// When no state-name function (GetStateName) is configured, the graph cannot
// be labelled and the empty string is returned instead.
func (fsm *FSM) GenerateGraph() string {
	if fsm.GetStateName == nil {
		return ""
	}
	return fsm.generateMermaidGraph()
}