summaryrefslogtreecommitdiff
path: root/internal/fsm/fsm.go
blob: b8fd644129c3f7800e3032b0de0e3ea96f708c1d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
// Package fsm defines a finite state machine that can save itself to a graph
// file. This graph file can be visualized using mermaid.js.
package fsm

import (
	"fmt"
	"os"
	"os/exec"
	"path"
	"sort"
	"github.com/eduvpn/eduvpn-common/types"
)

type (
	// FSMStateID identifies a single state of the state machine.
	FSMStateID      int8
	// FSMStateIDSlice is a list of state identifiers. It implements
	// sort.Interface so the identifiers can be sorted in ascending order.
	FSMStateIDSlice []FSMStateID
)

// Len returns the number of state identifiers (sort.Interface).
func (v FSMStateIDSlice) Len() int {
	return len(v)
}

// Less reports whether the identifier at index i is numerically smaller than
// the identifier at index j (sort.Interface).
func (v FSMStateIDSlice) Less(i, j int) bool {
	return v[i] < v[j]
}

// Swap exchanges the identifiers at indices i and j in place (sort.Interface).
func (v FSMStateIDSlice) Swap(i, j int) {
	v[i], v[j] = v[j], v[i]
}

// FSMTransition is an outgoing arrow of a state in the state graph.
type FSMTransition struct {
	// To is the state the machine moves to when this transition is taken.
	To          FSMStateID
	// Description is the label drawn on the arrow in the generated graph.
	Description string
}

// FSMStates maps each state identifier to the definition of that state.
type FSMStates map[FSMStateID]FSMState

// FSMState is a single node in the state graph.
type FSMState struct {
	// Transitions lists the outgoing arrows of this node, i.e. the states
	// that can be reached directly from it.
	Transitions []FSMTransition
}

// FSM represents the total graph: the state table, the current state and the
// hooks that run on transitions.
type FSM struct {
	// States is the map from state ID to state definition.
	States FSMStates

	// Current is the identifier of the state the machine is currently in.
	Current FSMStateID

	// Name is a descriptive name for this state machine.
	Name string

	// StateCallback is called after every successful transition with the old
	// state, the new state and the transition data. It returns whether the
	// transition was handled by the client.
	StateCallback func(FSMStateID, FSMStateID, interface{}) bool

	// Directory is the directory in which the state graph files
	// (graph.graph and graph.png) are written.
	Directory string

	// Generate controls whether the state graph is written to Directory on
	// every successful transition.
	Generate bool

	// GetStateName returns the display name of a state. When nil,
	// GenerateGraph produces an empty graph.
	GetStateName func(FSMStateID) string
}

// Init prepares the state machine for use. It installs the state table
// 'states', makes 'current' the active state, and stores the transition
// callback. It also stores the graph output directory, the state-name
// function and the graph generation flag. The Name field is left as is.
func (fsm *FSM) Init(
	current FSMStateID,
	states map[FSMStateID]FSMState,
	callback func(FSMStateID, FSMStateID, interface{}) bool,
	directory string,
	nameGen func(FSMStateID) string,
	generate bool,
) {
	// Graph output settings.
	fsm.Generate = generate
	fsm.Directory = directory
	fsm.GetStateName = nameGen

	// Machine definition and starting point.
	fsm.States = states
	fsm.StateCallback = callback
	fsm.Current = current
}

// InState reports whether 'check' is the state the machine is currently in.
func (fsm *FSM) InState(check FSMStateID) bool {
	return fsm.Current == check
}

// HasTransition reports whether the current state has an outgoing transition
// to the 'check' state. If the current state is not present in States, it has
// no transitions and the result is false.
func (fsm *FSM) HasTransition(check FSMStateID) bool {
	outgoing := fsm.States[fsm.Current].Transitions
	for i := range outgoing {
		if outgoing[i].To == check {
			return true
		}
	}
	return false
}

// getGraphFilename returns the path of a state graph file. The path is the
// file "graph" inside fsm.Directory, with 'extension' (e.g. ".graph" or ".png")
// appended verbatim.
//
// NOTE(review): path.Join always joins with '/'; filepath.Join would be the
// OS-correct choice for a filesystem path.
func (fsm *FSM) getGraphFilename(extension string) string {
	return path.Join(fsm.Directory, "graph") + extension
}

// writeGraph renders the state machine with GenerateGraph and stores it in
// <Directory>/graph.graph. When that file was written successfully, it then
// starts the mermaid CLI (mmdc) to render <Directory>/graph.png.
//
// Graph generation is a debugging aid and is strictly best effort. Failures
// are not reported: creating or writing the file, starting mmdc, and mmdc's
// exit status are all ignored.
func (fsm *FSM) writeGraph() {
	graph := fsm.GenerateGraph()
	graphFile := fsm.getGraphFilename(".graph")
	graphImgFile := fsm.getGraphFilename(".png")

	f, err := os.Create(graphFile)
	if err != nil {
		return
	}
	_, writeErr := f.WriteString(graph)
	// A failing Close can also mean the data never fully reached the file.
	closeErr := f.Close()
	if writeErr != nil || closeErr != nil {
		// Do not render an image from a missing or truncated graph file.
		return
	}

	cmd := exec.Command("mmdc", "-i", graphFile, "-o", graphImgFile, "--scale", "4")
	// Generating is best effort: if mmdc cannot be started (e.g. it is not
	// installed), the image is simply skipped.
	if err := cmd.Start(); err != nil {
		return
	}
	// Wait for the child in the background so its process resources are
	// released (on Unix it would otherwise linger as a zombie). This goroutine
	// ends as soon as mmdc exits.
	go func() {
		_ = cmd.Wait()
	}()
}

// GoTransitionRequired moves the state machine to 'newState' with the
// associated state data 'data', and requires the client to handle it.
// It returns nil when GoTransitionWithData reports the transition as handled.
// Otherwise it returns a wrapped error naming both states. This covers two
// cases: no transition to newState exists, or the callback returned false.
func (fsm *FSM) GoTransitionRequired(newState FSMStateID, data interface{}) error {
	// Remember where we came from: a successful transition changes Current.
	from := fsm.Current
	if fsm.GoTransitionWithData(newState, data) {
		return nil
	}

	cause := fmt.Errorf("required transition not handled, from: %s -> to: %s", fsm.GetStateName(from), fsm.GetStateName(newState))
	return types.NewWrappedError("failed required transition", cause)
}

// GoTransitionWithData moves the state machine to 'newState' with the
// associated state data 'data'. This only happens if the current state has a
// transition to it.
//
// On a valid transition it updates Current and, when Generate is set, writes
// the state graph. It then calls StateCallback(old, new, data) and returns its
// result, i.e. whether the client handled the transition. When no such
// transition exists, the machine is left unchanged and false is returned.
func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data interface{}) bool {
	if !fsm.HasTransition(newState) {
		return false
	}

	previous := fsm.Current
	fsm.Current = newState
	if fsm.Generate {
		fsm.writeGraph()
	}
	return fsm.StateCallback(previous, newState, data)
}

// GoTransition moves the state machine to 'newState' without meaningful state
// data. It is shorthand for GoTransitionWithData(newState, ""). It returns
// whether the transition happened and was handled by the client.
func (fsm *FSM) GoTransition(newState FSMStateID) bool {
	// No data means the callback is never required.
	const noData = ""
	return fsm.GoTransitionWithData(newState, noData)
}

// generateMermaidGraph renders the state machine as a mermaid.js flowchart
// ("graph TD") and returns it as a string.
//
// States are visited in ascending FSMStateID order so the output is
// deterministic despite random map iteration order. Each state gets exactly
// one "style" line, filled cyan for the current state and white otherwise.
// That line is followed by one edge line per outgoing transition, labelled
// with the transition's Description. Node identifiers and labels come from
// GetStateName, which must be non-nil.
func (fsm *FSM) generateMermaidGraph() string {
	ids := make(FSMStateIDSlice, 0, len(fsm.States))
	for id := range fsm.States {
		ids = append(ids, id)
	}
	sort.Sort(ids)

	graph := "graph TD\n"
	for _, id := range ids {
		name := fsm.GetStateName(id)

		// Emit the style once per state rather than once per transition. This
		// avoids duplicate lines and still highlights the current state when
		// it has no outgoing transitions.
		fill := "white"
		if id == fsm.Current {
			fill = "cyan"
		}
		graph += "\nstyle " + name + " fill:" + fill + "\n"

		for _, transition := range fsm.States[id].Transitions {
			graph += name + "(" + name + ") -->|" + transition.Description + "| " + fsm.GetStateName(transition.To) + "\n"
		}
	}
	return graph
}

// GenerateGraph returns the state machine as a mermaid.js graph. Without a
// GetStateName function (e.g. before Init) no graph can be built, and the
// empty string is returned instead.
func (fsm *FSM) GenerateGraph() string {
	if fsm.GetStateName == nil {
		return ""
	}
	return fsm.generateMermaidGraph()
}