aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthew Wild <mwild1@gmail.com>2022-03-17 17:45:27 +0000
committerMatthew Wild <mwild1@gmail.com>2022-03-17 17:45:27 +0000
commitb06dd038cd796b93f9fa7a578ba0f279feb2f2b0 (patch)
treeab2e3947dff6c5ca33db458dacb4e019fffc94b3
parentaa0d6f297d27f1c3ae1a94d05d9855af6dda4e45 (diff)
downloadprosody-b06dd038cd796b93f9fa7a578ba0f279feb2f2b0.tar.gz
prosody-b06dd038cd796b93f9fa7a578ba0f279feb2f2b0.zip
util.fsm: New utility lib for finite state machines
-rw-r--r--spec/util_fsm_spec.lua250
-rw-r--r--util/fsm.lua154
2 files changed, 404 insertions, 0 deletions
diff --git a/spec/util_fsm_spec.lua b/spec/util_fsm_spec.lua
new file mode 100644
index 00000000..9aff7ca6
--- /dev/null
+++ b/spec/util_fsm_spec.lua
@@ -0,0 +1,250 @@
+describe("util.fsm", function ()
+ local new_fsm = require "util.fsm".new;
+
+ do
+ local fsm = new_fsm({
+ transitions = {
+ { name = "melt", from = "solid", to = "liquid" };
+ { name = "freeze", from = "liquid", to = "solid" };
+ };
+ });
+
+ it("works", function ()
+ local water = fsm:init("liquid");
+ water:freeze();
+ assert.equal("solid", water.state);
+ water:melt();
+ assert.equal("liquid", water.state);
+ end);
+
+ it("does not allow invalid transitions", function ()
+ local water = fsm:init("liquid");
+ assert.has_errors(function ()
+ water:melt();
+ end, "Invalid state transition: liquid cannot melt");
+
+ water:freeze();
+ assert.equal("solid", water.state);
+
+ water:melt();
+ assert.equal("liquid", water.state);
+
+ assert.has_errors(function ()
+ water:melt();
+ end, "Invalid state transition: liquid cannot melt");
+ end);
+ end
+
+ it("notifies observers", function ()
+ local n = 0;
+ local has_become_solid = spy.new(function (transition)
+ assert.is_table(transition);
+ assert.equal("solid", transition.to);
+ assert.is_not_nil(transition.instance);
+ n = n + 1;
+ if n == 1 then
+ assert.is_nil(transition.from);
+ assert.is_nil(transition.from_attr);
+ elseif n == 2 then
+ assert.equal("liquid", transition.from);
+ assert.is_nil(transition.from_attr);
+ assert.equal("freeze", transition.name);
+ end
+ end);
+ local is_melting = spy.new(function (transition)
+ assert.is_table(transition);
+ assert.equal("melt", transition.name);
+ assert.is_not_nil(transition.instance);
+ end);
+ local fsm = new_fsm({
+ transitions = {
+ { name = "melt", from = "solid", to = "liquid" };
+ { name = "freeze", from = "liquid", to = "solid" };
+ };
+ state_handlers = {
+ solid = has_become_solid;
+ };
+
+ transition_handlers = {
+ melt = is_melting;
+ };
+ });
+
+ local water = fsm:init("liquid");
+ assert.spy(has_become_solid).was_not_called();
+
+ local ice = fsm:init("solid"); --luacheck: ignore 211/ice
+ assert.spy(has_become_solid).was_called(1);
+
+ water:freeze();
+
+ assert.spy(is_melting).was_not_called();
+ water:melt();
+ assert.spy(is_melting).was_called(1);
+ end);
+
+ local function test_machine(fsm_spec, expected_transitions, test_func)
+ fsm_spec.handlers = fsm_spec.handlers or {};
+ fsm_spec.handlers.transitioned = function (transition)
+ local expected_transition = table.remove(expected_transitions, 1);
+ assert.same(expected_transition, {
+ name = transition.name;
+ to = transition.to;
+ to_attr = transition.to_attr;
+ from = transition.from;
+ from_attr = transition.from_attr;
+ });
+ end;
+ local fsm = new_fsm(fsm_spec);
+ test_func(fsm);
+ assert.equal(0, #expected_transitions);
+ end
+
+
+ it("handles transitions with the same name", function ()
+ local expected_transitions = {
+ { name = nil , from = "none", to = "A" };
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "step", from = "C", to = "D" };
+ };
+
+ test_machine({
+ default_state = "none";
+ transitions = {
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "step", from = "C", to = "D" };
+ };
+ }, expected_transitions, function (fsm)
+ local i = fsm:init("A");
+ i:step(); -- B
+ i:step(); -- C
+ i:step(); -- D
+ assert.has_errors(function ()
+ i:step();
+ end, "Invalid state transition: D cannot step");
+ end);
+ end);
+
+ it("handles supports wildcard transitions", function ()
+ local expected_transitions = {
+ { name = nil , from = "none", to = "A" };
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "reset", from = "C", to = "A" };
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "step", from = "C", to = "D" };
+ };
+
+ test_machine({
+ default_state = "none";
+ transitions = {
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "step", from = "C", to = "D" };
+ { name = "reset", from = "*", to = "A" };
+ };
+ }, expected_transitions, function (fsm)
+ local i = fsm:init("A");
+ i:step(); -- B
+ i:step(); -- C
+ i:reset(); -- A
+ i:step(); -- B
+ i:step(); -- C
+ i:step(); -- D
+ assert.has_errors(function ()
+ i:step();
+ end, "Invalid state transition: D cannot step");
+ end);
+ end);
+
+ it("supports specifying multiple from states", function ()
+ local expected_transitions = {
+ { name = nil , from = "none", to = "A" };
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "reset", from = "C", to = "A" };
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "step", from = "C", to = "D" };
+ };
+
+ test_machine({
+ default_state = "none";
+ transitions = {
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "C" };
+ { name = "step", from = "C", to = "D" };
+ { name = "reset", from = {"B", "C", "D"}, to = "A" };
+ };
+ }, expected_transitions, function (fsm)
+ local i = fsm:init("A");
+ i:step(); -- B
+ i:step(); -- C
+ i:reset(); -- A
+ assert.has_errors(function ()
+ i:reset();
+ end, "Invalid state transition: A cannot reset");
+ i:step(); -- B
+ i:step(); -- C
+ i:step(); -- D
+ assert.has_errors(function ()
+ i:step();
+ end, "Invalid state transition: D cannot step");
+ end);
+ end);
+
+ it("handles transitions with the same start/end state", function ()
+ local expected_transitions = {
+ { name = nil , from = "none", to = "A" };
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "B" };
+ { name = "step", from = "B", to = "B" };
+ };
+
+ test_machine({
+ default_state = "none";
+ transitions = {
+ { name = "step", from = "A", to = "B" };
+ { name = "step", from = "B", to = "B" };
+ };
+ }, expected_transitions, function (fsm)
+ local i = fsm:init("A");
+ i:step(); -- B
+ i:step(); -- B
+ i:step(); -- B
+ end);
+ end);
+
+ it("can identify instances of a specific fsm", function ()
+ local fsm1 = new_fsm({ default_state = "a" });
+ local fsm2 = new_fsm({ default_state = "a" });
+
+ local i1 = fsm1:init();
+ local i2 = fsm2:init();
+
+ assert.truthy(fsm1:is_instance(i1));
+ assert.truthy(fsm2:is_instance(i2));
+
+ assert.falsy(fsm1:is_instance(i2));
+ assert.falsy(fsm2:is_instance(i1));
+ end);
+
+ it("errors when an invalid initial state is passed", function ()
+ local fsm1 = new_fsm({
+ transitions = {
+ { name = "", from = "A", to = "B" };
+ };
+ });
+
+ assert.has_no_errors(function ()
+ fsm1:init("A");
+ end);
+
+ assert.has_errors(function ()
+ fsm1:init("C");
+ end);
+ end);
+end);
diff --git a/util/fsm.lua b/util/fsm.lua
new file mode 100644
index 00000000..94a543d1
--- /dev/null
+++ b/util/fsm.lua
@@ -0,0 +1,154 @@
+local events = require "util.events";
+
+local fsm_methods = {};
+local fsm_mt = { __index = fsm_methods };
+
+local function is_fsm(o)
+ local mt = getmetatable(o);
+ return mt == fsm_mt;
+end
+
+local function notify_transition(fire_event, transition_event)
+ local ret;
+ ret = fire_event("transition", transition_event);
+ if ret ~= nil then return ret; end
+ if transition_event.from ~= transition_event.to then
+ ret = fire_event("leave/"..transition_event.from, transition_event);
+ if ret ~= nil then return ret; end
+ end
+ ret = fire_event("transition/"..transition_event.name, transition_event);
+ if ret ~= nil then return ret; end
+end
+
+local function notify_transitioned(fire_event, transition_event)
+ if transition_event.to ~= transition_event.from then
+ fire_event("enter/"..transition_event.to, transition_event);
+ end
+ if transition_event.name then
+ fire_event("transitioned/"..transition_event.name, transition_event);
+ end
+ fire_event("transitioned", transition_event);
+end
+
+local function do_transition(name)
+ return function (self, attr)
+ local new_state = self.fsm.states[self.state][name] or self.fsm.states["*"][name];
+ if not new_state then
+ return error(("Invalid state transition: %s cannot %s"):format(self.state, name));
+ end
+
+ local transition_event = {
+ instance = self;
+
+ name = name;
+ to = new_state;
+ to_attr = attr;
+
+ from = self.state;
+ from_attr = self.state_attr;
+ };
+
+ local fire_event = self.fsm.events.fire_event;
+ local ret = notify_transition(fire_event, transition_event);
+ if ret ~= nil then return nil, ret; end
+
+ self.state = new_state;
+ self.state_attr = attr;
+
+ notify_transitioned(fire_event, transition_event);
+ return true;
+ end;
+end
+
+local function new(desc)
+ local self = setmetatable({
+ default_state = desc.default_state;
+ events = events.new();
+ }, fsm_mt);
+
+ -- states[state_name][transition_name] = new_state_name
+ local states = { ["*"] = {} };
+ if desc.default_state then
+ states[desc.default_state] = {};
+ end
+ self.states = states;
+
+ local instance_methods = {};
+ self._instance_mt = { __index = instance_methods };
+
+ for _, transition in ipairs(desc.transitions or {}) do
+ local from_states = transition.from;
+ if type(from_states) ~= "table" then
+ from_states = { from_states };
+ end
+ for _, from in ipairs(from_states) do
+ if not states[from] then
+ states[from] = {};
+ end
+ if not states[transition.to] then
+ states[transition.to] = {};
+ end
+ if states[from][transition.name] then
+ return error(("Duplicate transition in FSM specification: %s from %s"):format(transition.name, from));
+ end
+ states[from][transition.name] = transition.to;
+ end
+
+ -- Add public method to trigger this transition
+ instance_methods[transition.name] = do_transition(transition.name);
+ end
+
+ if desc.state_handlers then
+ for state_name, handler in pairs(desc.state_handlers) do
+ self.events.add_handler("enter/"..state_name, handler);
+ end
+ end
+
+ if desc.transition_handlers then
+ for transition_name, handler in pairs(desc.transition_handlers) do
+ self.events.add_handler("transition/"..transition_name, handler);
+ end
+ end
+
+ if desc.handlers then
+ self.events.add_handlers(desc.handlers);
+ end
+
+ return self;
+end
+
+function fsm_methods:init(state_name, state_attr)
+ local initial_state = assert(state_name or self.default_state, "no initial state specified");
+ if not self.states[initial_state] then
+ return error("Invalid initial state: "..initial_state);
+ end
+ local instance = setmetatable({
+ fsm = self;
+ state = initial_state;
+ state_attr = state_attr;
+ }, self._instance_mt);
+
+ if initial_state ~= self.default_state then
+ local fire_event = self.events.fire_event;
+ notify_transitioned(fire_event, {
+ instance = instance;
+
+ to = initial_state;
+ to_attr = state_attr;
+
+ from = self.default_state;
+ });
+ end
+
+ return instance;
+end
+
+function fsm_methods:is_instance(o)
+ local mt = getmetatable(o);
+ return mt == self._instance_mt;
+end
+
+return {
+ new = new;
+ is_fsm = is_fsm;
+};