aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthew Wild <mwild1@gmail.com>2018-10-10 17:45:19 +0100
committerMatthew Wild <mwild1@gmail.com>2018-10-10 17:45:19 +0100
commite557cfb0123ef1e7fe366f25e38bbbccaa36ff4c (patch)
tree3c8c24f415bb667f80bca10e2428f5990d187d8d
parent09b659294a5762c559e38b00c062765657585471 (diff)
downloadprosody-e557cfb0123ef1e7fe366f25e38bbbccaa36ff4c.tar.gz
prosody-e557cfb0123ef1e7fe366f25e38bbbccaa36ff4c.zip
util.promise: ES6-like API for promises
-rw-r--r--spec/util_promise_spec.lua262
-rw-r--r--util/promise.lua133
2 files changed, 395 insertions, 0 deletions
diff --git a/spec/util_promise_spec.lua b/spec/util_promise_spec.lua
new file mode 100644
index 00000000..f0aec64c
--- /dev/null
+++ b/spec/util_promise_spec.lua
@@ -0,0 +1,262 @@
+local promise = require "util.promise";
+
+describe("util.promise", function ()
+ describe("new()", function ()
+ it("returns a promise object", function ()
+ assert(promise.new());
+ end);
+ end);
+ it("notifies immediately for fulfilled promises", function ()
+ local p = promise.new(function (resolve)
+ resolve("foo");
+ end);
+ local cb = spy.new(function (v)
+ assert.equal("foo", v);
+ end);
+ p:next(cb);
+ assert.spy(cb).was_called(1);
+ end);
+ it("notifies on fulfilment of pending promises", function ()
+ local r;
+ local p = promise.new(function (resolve)
+ r = resolve;
+ end);
+ local cb = spy.new(function (v)
+ assert.equal("foo", v);
+ end);
+ p:next(cb);
+ assert.spy(cb).was_called(0);
+ r("foo");
+ assert.spy(cb).was_called(1);
+ end);
+ it("allows chaining :next() calls", function ()
+ local r;
+ local result;
+ local p = promise.new(function (resolve)
+ r = resolve;
+ end);
+ local cb1 = spy.new(function (v)
+ assert.equal("foo", v);
+ return "bar";
+ end);
+ local cb2 = spy.new(function (v)
+ assert.equal("bar", v);
+ result = v;
+ end);
+ p:next(cb1):next(cb2);
+ assert.spy(cb1).was_called(0);
+ assert.spy(cb2).was_called(0);
+ r("foo");
+ assert.spy(cb1).was_called(1);
+ assert.spy(cb2).was_called(1);
+ assert.equal("bar", result);
+ end);
+ it("supports multiple :next() calls on the same promise", function ()
+ local r;
+ local result;
+ local p = promise.new(function (resolve)
+ r = resolve;
+ end);
+ local cb1 = spy.new(function (v)
+ assert.equal("foo", v);
+ result = v;
+ end);
+ local cb2 = spy.new(function (v)
+ assert.equal("foo", v);
+ result = v;
+ end);
+ p:next(cb1);
+ p:next(cb2);
+ assert.spy(cb1).was_called(0);
+ assert.spy(cb2).was_called(0);
+ r("foo");
+ assert.spy(cb1).was_called(1);
+ assert.spy(cb2).was_called(1);
+ assert.equal("foo", result);
+ end);
+ it("automatically rejects on error", function ()
+ local r;
+ local p = promise.new(function (resolve)
+ r = resolve;
+ error("oh no");
+ end);
+ local cb = spy.new(function () end);
+ local err_cb = spy.new(function (v)
+ assert.equal("oh no", v);
+ end);
+ p:next(cb, err_cb);
+ assert.spy(cb).was_called(0);
+ assert.spy(err_cb).was_called(1);
+ r("foo");
+ assert.spy(cb).was_called(0);
+ assert.spy(err_cb).was_called(1);
+ end);
+ it("supports reject()", function ()
+ local r, result;
+ local p = promise.new(function (resolve, reject)
+ r = reject;
+ end);
+ local cb = spy.new(function () end);
+ local err_cb = spy.new(function (v)
+ result = v;
+ assert.equal("oh doh", v);
+ end);
+ p:next(cb, err_cb);
+ assert.spy(cb).was_called(0);
+ assert.spy(err_cb).was_called(0);
+ r("oh doh");
+ assert.spy(cb).was_called(0);
+ assert.spy(err_cb).was_called(1);
+ assert.equal("oh doh", result);
+ end);
+ it("supports chaining of rejected promises", function ()
+ local r, result;
+ local p = promise.new(function (resolve, reject)
+ r = reject;
+ end);
+ local cb = spy.new(function () end);
+ local err_cb = spy.new(function (v)
+ result = v;
+ assert.equal("oh doh", v);
+ return "ok"
+ end);
+ local cb2 = spy.new(function (v)
+ result = v;
+ end);
+ local err_cb2 = spy.new(function (v) end);
+ p:next(cb, err_cb):next(cb2, err_cb2)
+ assert.spy(cb).was_called(0);
+ assert.spy(err_cb).was_called(0);
+ assert.spy(cb2).was_called(0);
+ assert.spy(err_cb2).was_called(0);
+ r("oh doh");
+ assert.spy(cb).was_called(0);
+ assert.spy(err_cb).was_called(1);
+ assert.spy(cb2).was_called(1);
+ assert.spy(err_cb2).was_called(0);
+ assert.equal("ok", result);
+ end);
+
+ describe("race()", function ()
+ it("works with fulfilled promises", function ()
+ local p1, p2 = promise.resolve("yep"), promise.resolve("nope");
+ local p = promise.race({ p1, p2 });
+ local result;
+ p:next(function (v)
+ result = v;
+ end);
+ assert.equal("yep", result);
+ end);
+ it("works with pending promises", function ()
+ local r1, r2;
+ local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end);
+ local p = promise.race({ p1, p2 });
+
+ local result;
+ local cb = spy.new(function (v)
+ result = v;
+ end);
+ p:next(cb);
+ assert.spy(cb).was_called(0);
+ r2("yep");
+ r1("nope");
+ assert.spy(cb).was_called(1);
+ assert.equal("yep", result);
+ end);
+ end);
+ describe("all()", function ()
+ it("works with fulfilled promises", function ()
+ local p1, p2 = promise.resolve("yep"), promise.resolve("nope");
+ local p = promise.all({ p1, p2 });
+ local result;
+ p:next(function (v)
+ result = v;
+ end);
+ assert.same({ "yep", "nope" }, result);
+ end);
+ it("works with pending promises", function ()
+ local r1, r2;
+ local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end);
+ local p = promise.all({ p1, p2 });
+
+ local result;
+ local cb = spy.new(function (v)
+ result = v;
+ end);
+ p:next(cb);
+ assert.spy(cb).was_called(0);
+ r2("yep");
+ assert.spy(cb).was_called(0);
+ r1("nope");
+ assert.spy(cb).was_called(1);
+ assert.same({ "nope", "yep" }, result);
+ end);
+ it("rejects if any promise rejects", function ()
+ local r1, r2;
+ local p1 = promise.new(function (resolve, reject) r1 = reject end);
+ local p2 = promise.new(function (resolve, reject) r2 = reject end);
+ local p = promise.all({ p1, p2 });
+
+ local result;
+ local cb = spy.new(function (v)
+ result = v;
+ end);
+ local cb_err = spy.new(function (v)
+ result = v;
+ end);
+ p:next(cb, cb_err);
+ assert.spy(cb).was_called(0);
+ assert.spy(cb_err).was_called(0);
+ r2("fail");
+ assert.spy(cb).was_called(0);
+ assert.spy(cb_err).was_called(1);
+ r1("nope");
+ assert.spy(cb).was_called(0);
+ assert.spy(cb_err).was_called(1);
+ assert.equal("fail", result);
+ end);
+ end);
+ describe("catch()", function ()
+ it("works", function ()
+ local result;
+ local p = promise.new(function (resolve)
+ error({ foo = true });
+ end);
+ local cb1 = spy.new(function (v)
+ result = v;
+ end);
+ assert.spy(cb1).was_called(0);
+ p:catch(cb1);
+ assert.spy(cb1).was_called(1);
+ assert.same({ foo = true }, result);
+ end);
+ end);
+ it("promises may be resolved by other promises", function ()
+ local r1, r2;
+ local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end);
+
+ local result;
+ local cb = spy.new(function (v)
+ result = v;
+ end);
+ p1:next(cb);
+ assert.spy(cb).was_called(0);
+
+ r1(p2);
+ assert.spy(cb).was_called(0);
+ r2("yep");
+ assert.spy(cb).was_called(1);
+ assert.equal("yep", result);
+ end);
+ describe("reject()", function ()
+ it("returns a rejected promise", function ()
+ local p = promise.reject("foo");
+ local cb = spy.new(function (v)
+ result = v;
+ end);
+ p:next(cb);
+ assert.spy(cb).was_called(1);
+ assert.spy(cb).was_called_with("foo");
+ end);
+ end);
+end);
diff --git a/util/promise.lua b/util/promise.lua
new file mode 100644
index 00000000..7184f5fb
--- /dev/null
+++ b/util/promise.lua
@@ -0,0 +1,133 @@
+local promise_methods = {};
+local promise_mt = { __name = "promise", __index = promise_methods };
+
+local function is_promise(o)
+ local mt = getmetatable(o);
+ return mt == promise_mt;
+end
+
+local function next_pending(self, on_fulfilled, on_rejected)
+ table.insert(self._pending_on_fulfilled, on_fulfilled);
+ table.insert(self._pending_on_rejected, on_rejected);
+end
+
+local function next_fulfilled(promise, on_fulfilled, on_rejected) -- luacheck: ignore 212/on_rejected
+ on_fulfilled(promise.value);
+end
+
+local function next_rejected(promise, on_fulfilled, on_rejected) -- luacheck: ignore 212/on_fulfilled
+ on_rejected(promise.reason);
+end
+
+local function promise_settle(promise, new_state, new_next, cbs, value)
+ if promise._state ~= "pending" then
+ return;
+ end
+ promise._state = new_state;
+ promise._next = new_next;
+ for _, cb in ipairs(cbs) do
+ cb(value);
+ end
+ return true;
+end
+
+local function new_resolve_functions(p)
+ local resolved = false;
+ local function _resolve(v)
+ if resolved then return; end
+ resolved = true;
+ if is_promise(v) then
+ v:next(new_resolve_functions(p));
+ elseif promise_settle(p, "fulfilled", next_fulfilled, p._pending_on_fulfilled, v) then
+ p.value = v;
+ end
+
+ end
+ local function _reject(e)
+ if resolved then return; end
+ resolved = true;
+ if promise_settle(p, "rejected", next_rejected, p._pending_on_rejected, e) then
+ p.reason = e;
+ end
+ end
+ return _resolve, _reject;
+end
+
+local function new(f)
+ local p = setmetatable({ _state = "pending", _next = next_pending, _pending_on_fulfilled = {}, _pending_on_rejected = {} }, promise_mt);
+ if f then
+ local resolve, reject = new_resolve_functions(p);
+ local ok, ret = pcall(f, resolve, reject);
+ if not ok and p._state == "pending" then
+ reject(ret);
+ end
+ end
+ return p;
+end
+
+local function wrap_handler(f, resolve, reject)
+ return function (param)
+ local ok, ret = pcall(f, param);
+ if ok then
+ resolve(ret);
+ else
+ reject(ret);
+ end
+ end;
+end
+
+function promise_methods:next(on_fulfilled, on_rejected)
+ return new(function (resolve, reject)
+ self:_next(
+ on_fulfilled and wrap_handler(on_fulfilled, resolve, reject) or nil,
+ on_rejected and wrap_handler(on_rejected, resolve, reject) or nil
+ );
+ end);
+end
+
+function promise_methods:catch(on_rejected)
+ return self:next(nil, on_rejected);
+end
+
+local function all(promises)
+ return new(function (resolve, reject)
+ local count, total, results = 0, #promises, {};
+ for i = 1, total do
+ promises[i]:next(function (v)
+ results[i] = v;
+ count = count + 1;
+ if count == total then
+ resolve(results);
+ end
+ end, reject);
+ end
+ end);
+end
+
+local function race(promises)
+ return new(function (resolve, reject)
+ for i = 1, #promises do
+ promises[i]:next(resolve, reject);
+ end
+ end);
+end
+
+local function resolve(v)
+ return new(function (_resolve)
+ _resolve(v);
+ end);
+end
+
+local function reject(v)
+ return new(function (_reject)
+ _reject(v);
+ end);
+end
+
+return {
+ new = new;
+ resolve = resolve;
+ reject = reject;
+ all = all;
+ race = race;
+}