diff options
author | Matthew Wild <mwild1@gmail.com> | 2018-10-10 17:45:19 +0100 |
---|---|---|
committer | Matthew Wild <mwild1@gmail.com> | 2018-10-10 17:45:19 +0100 |
commit | e557cfb0123ef1e7fe366f25e38bbbccaa36ff4c (patch) | |
tree | 3c8c24f415bb667f80bca10e2428f5990d187d8d | |
parent | 09b659294a5762c559e38b00c062765657585471 (diff) | |
download | prosody-e557cfb0123ef1e7fe366f25e38bbbccaa36ff4c.tar.gz prosody-e557cfb0123ef1e7fe366f25e38bbbccaa36ff4c.zip |
util.promise: ES6-like API for promises
-rw-r--r-- | spec/util_promise_spec.lua | 262 | ||||
-rw-r--r-- | util/promise.lua | 133 |
2 files changed, 395 insertions, 0 deletions
diff --git a/spec/util_promise_spec.lua b/spec/util_promise_spec.lua new file mode 100644 index 00000000..f0aec64c --- /dev/null +++ b/spec/util_promise_spec.lua @@ -0,0 +1,262 @@ +local promise = require "util.promise"; + +describe("util.promise", function () + describe("new()", function () + it("returns a promise object", function () + assert(promise.new()); + end); + end); + it("notifies immediately for fulfilled promises", function () + local p = promise.new(function (resolve) + resolve("foo"); + end); + local cb = spy.new(function (v) + assert.equal("foo", v); + end); + p:next(cb); + assert.spy(cb).was_called(1); + end); + it("notifies on fulfilment of pending promises", function () + local r; + local p = promise.new(function (resolve) + r = resolve; + end); + local cb = spy.new(function (v) + assert.equal("foo", v); + end); + p:next(cb); + assert.spy(cb).was_called(0); + r("foo"); + assert.spy(cb).was_called(1); + end); + it("allows chaining :next() calls", function () + local r; + local result; + local p = promise.new(function (resolve) + r = resolve; + end); + local cb1 = spy.new(function (v) + assert.equal("foo", v); + return "bar"; + end); + local cb2 = spy.new(function (v) + assert.equal("bar", v); + result = v; + end); + p:next(cb1):next(cb2); + assert.spy(cb1).was_called(0); + assert.spy(cb2).was_called(0); + r("foo"); + assert.spy(cb1).was_called(1); + assert.spy(cb2).was_called(1); + assert.equal("bar", result); + end); + it("supports multiple :next() calls on the same promise", function () + local r; + local result; + local p = promise.new(function (resolve) + r = resolve; + end); + local cb1 = spy.new(function (v) + assert.equal("foo", v); + result = v; + end); + local cb2 = spy.new(function (v) + assert.equal("foo", v); + result = v; + end); + p:next(cb1); + p:next(cb2); + assert.spy(cb1).was_called(0); + assert.spy(cb2).was_called(0); + r("foo"); + assert.spy(cb1).was_called(1); + assert.spy(cb2).was_called(1); + assert.equal("foo", result); + end); + it("automatically rejects on error", function () + local r; + local p = promise.new(function (resolve) + r = resolve; + error("oh no"); + end); + local cb = spy.new(function () end); + local err_cb = spy.new(function (v) + assert.equal("oh no", v); + end); + p:next(cb, err_cb); + assert.spy(cb).was_called(0); + assert.spy(err_cb).was_called(1); + r("foo"); + assert.spy(cb).was_called(0); + assert.spy(err_cb).was_called(1); + end); + it("supports reject()", function () + local r, result; + local p = promise.new(function (resolve, reject) + r = reject; + end); + local cb = spy.new(function () end); + local err_cb = spy.new(function (v) + result = v; + assert.equal("oh doh", v); + end); + p:next(cb, err_cb); + assert.spy(cb).was_called(0); + assert.spy(err_cb).was_called(0); + r("oh doh"); + assert.spy(cb).was_called(0); + assert.spy(err_cb).was_called(1); + assert.equal("oh doh", result); + end); + it("supports chaining of rejected promises", function () + local r, result; + local p = promise.new(function (resolve, reject) + r = reject; + end); + local cb = spy.new(function () end); + local err_cb = spy.new(function (v) + result = v; + assert.equal("oh doh", v); + return "ok" + end); + local cb2 = spy.new(function (v) + result = v; + end); + local err_cb2 = spy.new(function (v) end); + p:next(cb, err_cb):next(cb2, err_cb2) + assert.spy(cb).was_called(0); + assert.spy(err_cb).was_called(0); + assert.spy(cb2).was_called(0); + assert.spy(err_cb2).was_called(0); + r("oh doh"); + assert.spy(cb).was_called(0); + assert.spy(err_cb).was_called(1); + assert.spy(cb2).was_called(1); + assert.spy(err_cb2).was_called(0); + assert.equal("ok", result); + end); + + describe("race()", function () + it("works with fulfilled promises", function () + local p1, p2 = promise.resolve("yep"), promise.resolve("nope"); + local p = promise.race({ p1, p2 }); + local result; + p:next(function (v) + result = v; + end); + assert.equal("yep", result); + end); + it("works with pending promises", function () + local r1, r2; + local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end); + local p = promise.race({ p1, p2 }); + + local result; + local cb = spy.new(function (v) + result = v; + end); + p:next(cb); + assert.spy(cb).was_called(0); + r2("yep"); + r1("nope"); + assert.spy(cb).was_called(1); + assert.equal("yep", result); + end); + end); + describe("all()", function () + it("works with fulfilled promises", function () + local p1, p2 = promise.resolve("yep"), promise.resolve("nope"); + local p = promise.all({ p1, p2 }); + local result; + p:next(function (v) + result = v; + end); + assert.same({ "yep", "nope" }, result); + end); + it("works with pending promises", function () + local r1, r2; + local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end); + local p = promise.all({ p1, p2 }); + + local result; + local cb = spy.new(function (v) + result = v; + end); + p:next(cb); + assert.spy(cb).was_called(0); + r2("yep"); + assert.spy(cb).was_called(0); + r1("nope"); + assert.spy(cb).was_called(1); + assert.same({ "nope", "yep" }, result); + end); + it("rejects if any promise rejects", function () + local r1, r2; + local p1 = promise.new(function (resolve, reject) r1 = reject end); + local p2 = promise.new(function (resolve, reject) r2 = reject end); + local p = promise.all({ p1, p2 }); + + local result; + local cb = spy.new(function (v) + result = v; + end); + local cb_err = spy.new(function (v) + result = v; + end); + p:next(cb, cb_err); + assert.spy(cb).was_called(0); + assert.spy(cb_err).was_called(0); + r2("fail"); + assert.spy(cb).was_called(0); + assert.spy(cb_err).was_called(1); + r1("nope"); + assert.spy(cb).was_called(0); + assert.spy(cb_err).was_called(1); + assert.equal("fail", result); + end); + end); + describe("catch()", function () + it("works", function () + local result; + local p = promise.new(function (resolve) + error({ foo = true }); + end); + local cb1 = spy.new(function (v) + result = v; + end); + assert.spy(cb1).was_called(0); + p:catch(cb1); + assert.spy(cb1).was_called(1); + assert.same({ foo = true }, result); + end); + end); + it("promises may be resolved by other promises", function () + local r1, r2; + local p1, p2 = promise.new(function (resolve) r1 = resolve end), promise.new(function (resolve) r2 = resolve end); + + local result; + local cb = spy.new(function (v) + result = v; + end); + p1:next(cb); + assert.spy(cb).was_called(0); + + r1(p2); + assert.spy(cb).was_called(0); + r2("yep"); + assert.spy(cb).was_called(1); + assert.equal("yep", result); + end); + describe("reject()", function () + it("returns a rejected promise", function () + local p = promise.reject("foo"); + local cb = spy.new(function (v) + result = v; + end); + p:next(cb); + assert.spy(cb).was_called(1); + assert.spy(cb).was_called_with("foo"); + end); + end); +end); diff --git a/util/promise.lua b/util/promise.lua new file mode 100644 index 00000000..7184f5fb --- /dev/null +++ b/util/promise.lua @@ -0,0 +1,133 @@ +local promise_methods = {}; +local promise_mt = { __name = "promise", __index = promise_methods }; + +local function is_promise(o) + local mt = getmetatable(o); + return mt == promise_mt; +end + +local function next_pending(self, on_fulfilled, on_rejected) + table.insert(self._pending_on_fulfilled, on_fulfilled); + table.insert(self._pending_on_rejected, on_rejected); +end + +local function next_fulfilled(promise, on_fulfilled, on_rejected) -- luacheck: ignore 212/on_rejected + on_fulfilled(promise.value); +end + +local function next_rejected(promise, on_fulfilled, on_rejected) -- luacheck: ignore 212/on_fulfilled + on_rejected(promise.reason); +end + +local function promise_settle(promise, new_state, new_next, cbs, value) + if promise._state ~= "pending" then + return; + end + promise._state = new_state; + promise._next = new_next; + for _, cb in ipairs(cbs) do + cb(value); + end + return true; +end + +local function new_resolve_functions(p) + local resolved = false; + local function _resolve(v) + if resolved then return; end + resolved = true; + if is_promise(v) then + v:next(new_resolve_functions(p)); + elseif promise_settle(p, "fulfilled", next_fulfilled, p._pending_on_fulfilled, v) then + p.value = v; + end + + end + local function _reject(e) + if resolved then return; end + resolved = true; + if promise_settle(p, "rejected", next_rejected, p._pending_on_rejected, e) then + p.reason = e; + end + end + return _resolve, _reject; +end + +local function new(f) + local p = setmetatable({ _state = "pending", _next = next_pending, _pending_on_fulfilled = {}, _pending_on_rejected = {} }, promise_mt); + if f then + local resolve, reject = new_resolve_functions(p); + local ok, ret = pcall(f, resolve, reject); + if not ok and p._state == "pending" then + reject(ret); + end + end + return p; +end + +local function wrap_handler(f, resolve, reject) + return function (param) + local ok, ret = pcall(f, param); + if ok then + resolve(ret); + else + reject(ret); + end + end; +end + +function promise_methods:next(on_fulfilled, on_rejected) + return new(function (resolve, reject) + self:_next( + on_fulfilled and wrap_handler(on_fulfilled, resolve, reject) or nil, + on_rejected and wrap_handler(on_rejected, resolve, reject) or nil + ); + end); +end + +function promise_methods:catch(on_rejected) + return self:next(nil, on_rejected); +end + +local function all(promises) + return new(function (resolve, reject) + local count, total, results = 0, #promises, {}; + for i = 1, total do + promises[i]:next(function (v) + results[i] = v; + count = count + 1; + if count == total then + resolve(results); + end + end, reject); + end + end); +end + +local function race(promises) + return new(function (resolve, reject) + for i = 1, #promises do + promises[i]:next(resolve, reject); + end + end); +end + +local function resolve(v) + return new(function (_resolve) + _resolve(v); + end); +end + +local function reject(v) + return new(function (_reject) + _reject(v); + end); +end + +return { + new = new; + resolve = resolve; + reject = reject; + all = all; + race = race; +} |