aboutsummaryrefslogtreecommitdiffstats
path: root/util/async.lua
blob: 968ec80495489b51b66a8e1e1ba6f8c201b7e14d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
local log = require "util.logger".init("util.async");

--- Resume an async thread that previously yielded "wait".
-- Called once whatever the thread was blocked on has completed (e.g. from a
-- waiter's done() or a guard's exit()).
-- If the resumed thread raises an error, the owning runner is recovered from
-- the dead coroutine, its error watcher is invoked with a traceback, and the
-- runner is put back into the 'ready' state so its remaining queue is processed.
-- @param thread coroutine to resume
-- @return false if the thread was not suspended (nothing was resumed), true otherwise
local function runner_continue(thread)
	-- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure)
	if coroutine.status(thread) ~= "suspended" then -- This should suffice
		return false;
	end
	local ok, state, runner = coroutine.resume(thread);
	if not ok then
		local err = state;
		-- A crashed coroutine cannot hand us its runner, so recover it from the
		-- outermost frame: the first local of the coroutine's main function
		-- (see runner_create_thread) is the runner object.
		local level = 0;
		while debug.getinfo(thread, level, "") do level = level + 1; end
		if level > 0 then
			runner = select(2, debug.getlocal(thread, level-1, 1));
		end
		if type(runner) ~= "table" then
			log("error", "Unable to locate runner for crashed async thread: %s", tostring(err));
			return true;
		end
		-- The dead thread must not be reused, and the runner must leave the
		-- 'waiting' state; otherwise run() would refuse to process anything
		-- queued from now on.
		runner.state, runner.thread = "ready", nil;
		local error_handler = runner.watchers.error;
		if error_handler then error_handler(runner, debug.traceback(thread, err)); end
		-- Items may have been queued while the thread was waiting.
		runner:run();
	elseif state == "ready" then
		-- If state is 'ready', it is our responsibility to update runner.state from 'waiting'.
		-- We also have to :run(), because the queue might have further items that will not be
		-- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer).
		runner.state = "ready";
		runner:run();
	end
	return true;
end

--- Create a wait()/done() pair bound to the calling async thread.
-- wait() suspends the thread (yielding "wait") until done() has been called
-- `num` times (default 1); if that count was already reached it returns at once.
-- done() decrements the counter and, once it reaches zero while the thread is
-- blocked in wait(), resumes the thread. Calling done() more often than `num`
-- raises an error.
-- @param num number of done() calls to wait for (default 1)
-- @return wait, done
local function waiter(num)
	local thread = coroutine.running();
	if thread == nil then
		error("Not running in an async context, see http://prosody.im/doc/developers/async");
	end
	local remaining = num or 1;
	local is_blocked = false;

	local function wait()
		if remaining ~= 0 then
			is_blocked = true;
			coroutine.yield("wait");
		end
		-- remaining == 0: everything already finished, nothing to wait for
	end

	local function done()
		remaining = remaining - 1;
		if remaining < 0 then
			error("done() called too many times");
		end
		if remaining == 0 and is_blocked then
			runner_continue(thread);
		end
	end

	return wait, done;
end

--- Create a guard factory providing mutual exclusion between async threads.
-- The returned function guard(id, func) must be called from inside an async
-- thread. The first thread to claim an id proceeds immediately; later threads
-- claiming the same id are queued (FIFO) and suspended by yielding "wait" until
-- the current holder releases the guard. A nil id uses a shared default key.
--  * With func: func() runs while the guard is held, then the guard is
--    released. NOTE: if func raises an error the guard is never released;
--    pcall is avoided because in Lua 5.1 it cannot be yielded across, and
--    func may itself need to wait.
--  * Without func: returns exit(), which releases the guard. Only the first
--    call has any effect; further calls are ignored.
-- @return guard function(id, func)
local function guarder()
	local guards = {};
	-- Stand-in key for a nil id (a nil table key would raise "table index is nil").
	local default_id = {};
	return function (id, func)
		if id == nil then
			id = default_id;
		end
		local thread = coroutine.running();
		if not thread then
			error("Not running in an async context, see http://prosody.im/doc/developers/async");
		end
		local guard = guards[id];
		if not guard then
			-- Nobody holds this id: claim it. The table doubles as the wait queue.
			guard = {};
			guards[id] = guard;
			log("debug", "New guard!");
		else
			-- Held by another thread: queue up and sleep until exit() hands over.
			table.insert(guard, thread);
			log("debug", "Guarded. %d threads waiting.", #guard)
			coroutine.yield("wait");
		end
		local exited = false;
		local function exit()
			-- A second release would wake another waiter while the guard is
			-- already held, or clear a newer guard for the same id.
			if exited then return; end
			exited = true;
			local next_waiting = table.remove(guard, 1);
			if next_waiting then
				log("debug", "guard: Executing next waiting thread (%d left)", #guard)
				runner_continue(next_waiting);
			else
				log("debug", "Guard off duty.")
				guards[id] = nil;
			end
		end
		if func then
			func();
			exit();
			return;
		end
		return exit;
	end;
end

-- Metatable shared by all runner objects. Methods defined on runner_mt
-- (run, enqueue) are found by runners through __index.
local runner_mt = {};
runner_mt.__index = runner_mt;

--- Create and prime the coroutine that executes a runner's work items.
-- Each time the coroutine is resumed at its idle point, the resume arguments
-- are passed to func. It then yields "ready" plus the runner before going idle again.
-- @param func work function called once per queued item
-- @param self the runner owning the coroutine
-- @return the suspended coroutine, parked at its first "ready" yield
local function runner_create_thread(func, self)
	-- NOTE: runner_continue() locates the runner of a crashed coroutine by
	-- reading the first local of its outermost frame, i.e. this parameter.
	-- Keep the runner as the first parameter of the coroutine's main function.
	local function body(owner)
		repeat
			func(coroutine.yield("ready", owner));
		until false;
	end
	local thread = coroutine.create(body);
	-- Prime it: it runs up to the first yield and then waits for input.
	assert(coroutine.resume(thread, self));
	return thread;
end

-- Default watcher table for runners created without watchers. It is shared by
-- all of them, so mutating it would affect every such runner.
local empty_watchers = {};

--- Create a new async runner.
-- @param func function invoked (inside the runner's coroutine) for each queued item
-- @param watchers optional table of state callbacks, keyed by state name
--   (e.g. "ready", "waiting", "error")
-- @param data arbitrary user data stored on the runner
-- @return runner object (starts in the "ready" state with an empty queue)
local function runner(func, watchers, data)
	local new_runner = {
		func = func;
		thread = false; -- coroutine is created lazily by run()
		state = "ready";
		notified_state = "ready";
		queue = {};
		watchers = watchers or empty_watchers;
		data = data;
	};
	return setmetatable(new_runner, runner_mt);
end

--- Optionally queue an item, then process as much of the queue as possible.
-- Items are passed, in order, to the runner's coroutine. Processing stops when
-- the queue is empty or an item blocks by yielding "wait"; the remaining items
-- stay queued until the runner is continued. If an item raises an error, its
-- traceback is collected and processing continues on a fresh coroutine.
-- @param input optional work item appended to the queue (nil queues nothing)
-- @return true, state, n where state is the runner state ("ready" or
--   "waiting"), "error" if any item raised during this call, or the current
--   busy state if the runner was not ready; n is the number of items still queued
function runner_mt:run(input)
	if input ~= nil then
		table.insert(self.queue, input);
	end
	if self.state ~= "ready" then
		-- Busy (running or waiting): the item stays queued and is processed later.
		return true, self.state, #self.queue;
	end

	local q, thread = self.queue, self.thread;
	if not thread or coroutine.status(thread) == "dead" then
		thread = runner_create_thread(self.func, self);
		self.thread = thread;
	end

	local n, state, errors = #q, self.state, nil;
	self.state = "running";
	while n > 0 and state == "ready" do
		-- A previous item may have crashed the coroutine. Continue on a fresh
		-- one instead of resuming the dead thread, which would fail every
		-- remaining item with "cannot resume dead coroutine".
		if not thread then
			thread = runner_create_thread(self.func, self);
			self.thread = thread;
		end
		local consumed;
		for i = 1,n do
			local ok, new_state = coroutine.resume(thread, q[i]);
			if not ok then
				-- Keep every traceback so no failure is silently dropped.
				errors = errors or {};
				errors[#errors+1] = debug.traceback(thread, new_state);
				consumed, state = i, "ready";
				thread, self.thread = nil, nil;
				break;
			elseif new_state == "wait" then
				-- Item is blocked; the runner resumes via runner_continue().
				consumed, state = i, "waiting";
				break;
			end
		end
		if not consumed then consumed = n; end
		-- Items may have been queued re-entrantly while processing.
		if q[n+1] ~= nil then
			n = #q;
		end
		-- Shift unconsumed items to the front and clear the tail.
		for i = 1, n do
			q[i] = q[consumed+i];
		end
		n = #q;
	end
	self.state = state;
	if errors then
		-- Errors take precedence: the state-change notification is skipped
		-- (notified_state is left untouched) and "error" is returned.
		local handler = self.watchers.error;
		if handler then
			for i = 1, #errors do
				handler(self, errors[i]);
			end
		end
		return true, "error", n;
	end
	if state ~= self.notified_state then
		self.notified_state = state;
		local handler = self.watchers[state];
		if handler then handler(self, nil); end
	end
	return true, state, n;
end

--- Append a work item to the runner's queue without processing it.
-- Queued items are picked up by a later run().
-- @param input work item to queue
function runner_mt:enqueue(input)
	local queue = self.queue;
	queue[#queue + 1] = input;
end

-- Public API of util.async.
return {
	waiter = waiter;
	guarder = guarder;
	runner = runner;
};