aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/connect.lua26
-rw-r--r--net/resolvers/basic.lua160
-rw-r--r--net/resolvers/service.lua81
-rw-r--r--plugins/mod_admin_shell.lua31
-rw-r--r--spec/net_resolvers_service_spec.lua241
-rw-r--r--spec/util_poll_spec.lua35
-rw-r--r--spec/util_table_spec.lua11
-rw-r--r--util-src/table.c46
-rw-r--r--util/array.lua16
-rw-r--r--util/logger.lua16
-rw-r--r--util/prosodyctl/shell.lua8
-rw-r--r--util/stanza.lua32
12 files changed, 576 insertions, 127 deletions
diff --git a/net/connect.lua b/net/connect.lua
index 4b602be4..241cc65b 100644
--- a/net/connect.lua
+++ b/net/connect.lua
@@ -1,6 +1,7 @@
local server = require "net.server";
local log = require "util.logger".init("net.connect");
local new_id = require "util.id".short;
+local timer = require "util.timer";
-- TODO #1246 Happy Eyeballs
-- FIXME RFC 6724
@@ -28,11 +29,7 @@ local pending_connection_listeners = {};
local function attempt_connection(p)
p:log("debug", "Checking for targets...");
- if p.conn then
- pending_connections_map[p.conn] = nil;
- p.conn = nil;
- end
- p.target_resolver:next(function (conn_type, ip, port, extra)
+ p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available)
if not conn_type then
-- No more targets to try
p:log("debug", "No more connection targets to try", p.target_resolver.last_error);
@@ -49,8 +46,16 @@ local function attempt_connection(p)
p.last_error = err or "unknown reason";
return attempt_connection(p);
end
- p.conn = conn;
+ p.conns[conn] = true;
pending_connections_map[conn] = p;
+ if more_targets_available then
+ timer.add_task(0.250, function ()
+ if not p.connected then
+ p:log("debug", "Still not connected, making parallel connection attempt...");
+ attempt_connection(p);
+ end
+ end);
+ end
end);
end
@@ -62,6 +67,13 @@ function pending_connection_listeners.onconnect(conn)
return;
end
pending_connections_map[conn] = nil;
+ if p.connected then
+ -- We already succeeded in connecting
+ p.conns[conn] = nil;
+ conn:close();
+ return;
+ end
+ p.connected = true;
p:log("debug", "Successfully connected");
conn:setlistener(p.listeners, p.data);
return p.listeners.onconnect(conn);
@@ -73,6 +85,7 @@ function pending_connection_listeners.ondisconnect(conn, reason)
log("warn", "Failed connection, but unexpected!");
return;
end
+ p.conns[conn] = nil;
p.last_error = reason or "unknown reason";
p:log("debug", "Connection attempt failed: %s", p.last_error);
attempt_connection(p);
@@ -85,6 +98,7 @@ local function connect(target_resolver, listeners, options, data)
listeners = assert(listeners);
options = options or {};
data = data;
+ conns = {};
}, pending_connection_mt);
p:log("debug", "Starting connection process");
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 305bce76..15338ff4 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -2,13 +2,59 @@ local adns = require "net.adns";
local inet_pton = require "util.net".pton;
local inet_ntop = require "util.net".ntop;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
+local promise = require "util.promise";
+local t_move = require "util.table".move;
local methods = {};
local resolver_mt = { __index = methods };
-- FIXME RFC 6724
+local function do_dns_lookup(self, dns_resolver, record_type, name)
+ return promise.new(function (resolve, reject)
+ local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil;
+ if ipv and self.extra["use_ipv"..ipv] == false then
+ return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type));
+ elseif record_type == "TLSA" and self.extra.use_dane ~= true then
+ return reject("DANE disabled - TLSA lookup skipped");
+ end
+ dns_resolver:lookup(function (answer, err)
+ if not answer then
+ return reject(err);
+ elseif answer.bogus then
+ return reject(("Validation error in %s lookup"):format(record_type));
+ elseif answer.status and #answer == 0 then
+ return reject(("%s in %s lookup"):format(answer.status, record_type));
+ end
+
+ local targets = { secure = answer.secure };
+ for _, record in ipairs(answer) do
+ if ipv then
+ table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra });
+ else
+ table.insert(targets, record[record_type:lower()]);
+ end
+ end
+ return resolve(targets);
+ end, name, record_type, "IN");
+ end);
+end
+
+local function merge_targets(ipv4_targets, ipv6_targets)
+ local result = { secure = ipv4_targets.secure and ipv6_targets.secure };
+ local common_length = math.min(#ipv4_targets, #ipv6_targets);
+ for i = 1, common_length do
+ table.insert(result, ipv6_targets[i]);
+ table.insert(result, ipv4_targets[i]);
+ end
+ if common_length < #ipv4_targets then
+ t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result);
+ elseif common_length < #ipv6_targets then
+ t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result);
+ end
+ return result;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
@@ -18,7 +64,7 @@ function methods:next(cb)
return;
end
local next_target = table.remove(self.targets, 1);
- cb(unpack(next_target, 1, 4));
+ cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]);
return;
end
@@ -28,91 +74,45 @@ function methods:next(cb)
return;
end
- local secure = true;
- local tlsa = {};
- local targets = {};
- local n = 3;
- local function ready()
- n = n - 1;
- if n > 0 then return; end
- self.targets = targets;
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+
+ local dns_lookups = {
+ ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname);
+ ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname);
+ tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname));
+ };
+
+ promise.all_settled(dns_lookups):next(function (dns_results)
+ -- Combine targets, assign to self.targets, self:next(cb)
+ local have_ipv4 = dns_results.ipv4.status == "fulfilled";
+ local have_ipv6 = dns_results.ipv6.status == "fulfilled";
+
+ if have_ipv4 and have_ipv6 then
+ self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value);
+ elseif have_ipv4 then
+ self.targets = dns_results.ipv4.value;
+ elseif have_ipv6 then
+ self.targets = dns_results.ipv6.value;
+ else
+ self.targets = {};
+ end
+
if self.extra and self.extra.use_dane then
- if secure and tlsa[1] then
- self.extra.tlsa = tlsa;
+ if self.targets.secure and dns_results.tlsa.status == "fulfilled" then
+ self.extra.tlsa = dns_results.tlsa.value;
self.extra.dane_hostname = self.hostname;
else
self.extra.tlsa = nil;
self.extra.dane_hostname = nil;
end
end
- self:next(cb);
- end
- -- Resolve DNS to target list
- local dns_resolver = adns.resolver();
-
- if not self.extra or self.extra.use_ipv4 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in A lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in A lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "A", "IN");
- else
- ready();
- end
-
- if not self.extra or self.extra.use_ipv6 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in AAAA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in AAAA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "AAAA", "IN");
- else
- ready();
- end
-
- if self.extra and self.extra.use_dane == true then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(tlsa, record.tlsa);
- end
- if answer.bogus then
- self.last_error = "Validation error in TLSA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in TLSA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN");
- else
- ready();
- end
+ self:next(cb);
+ end):catch(function (err)
+ self.last_error = err;
+ self.targets = {};
+ end);
end
local function new(hostname, port, conn_type, extra)
@@ -137,7 +137,7 @@ local function new(hostname, port, conn_type, extra)
hostname = ascii_host;
port = port;
conn_type = conn_type;
- extra = extra;
+ extra = extra or {};
targets = targets;
}, resolver_mt);
end
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
index 3810cac8..a7ce76a3 100644
--- a/net/resolvers/service.lua
+++ b/net/resolvers/service.lua
@@ -2,23 +2,78 @@ local adns = require "net.adns";
local basic = require "net.resolvers.basic";
local inet_pton = require "util.net".pton;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
+local function new_target_selector(rrset)
+ local rr_count = rrset and #rrset;
+ if not rr_count or rr_count == 0 then
+ rrset = nil;
+ else
+ table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end);
+ end
+ local rrset_pos = 1;
+ local priority_bucket, bucket_total_weight, bucket_len, bucket_used;
+ return function ()
+ if not rrset then return; end
+
+ if not priority_bucket or bucket_used >= bucket_len then
+ if rrset_pos > rr_count then return; end -- Used up all records
+
+ -- Going to start on a new priority now. Gather up all the next
+ -- records with the same priority and add them to priority_bucket
+ priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0;
+ local current_priority;
+ repeat
+ local curr_record = rrset[rrset_pos].srv;
+ if not current_priority then
+ current_priority = curr_record.priority;
+ elseif current_priority ~= curr_record.priority then
+ break;
+ end
+ table.insert(priority_bucket, curr_record);
+ bucket_total_weight = bucket_total_weight + curr_record.weight;
+ bucket_len = bucket_len + 1;
+ rrset_pos = rrset_pos + 1;
+ until rrset_pos > rr_count;
+ end
+
+ bucket_used = bucket_used + 1;
+ local n, running_total = math.random(0, bucket_total_weight), 0;
+ local target_record;
+ for i = 1, bucket_len do
+ local candidate = priority_bucket[i];
+ if candidate then
+ running_total = running_total + candidate.weight;
+ if running_total >= n then
+ target_record = candidate;
+ bucket_total_weight = bucket_total_weight - candidate.weight;
+ priority_bucket[i] = nil;
+ break;
+ end
+ end
+ end
+ return target_record;
+ end;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
- if self.targets then
- if not self.resolver then
- if #self.targets == 0 then
+ if self.resolver or self._get_next_target then
+ if not self.resolver then -- Do we have a basic resolver currently?
+ -- We don't, so fetch a new SRV target, create a new basic resolver for it
+ local next_srv_target = self._get_next_target and self._get_next_target();
+ if not next_srv_target then
+ -- No more SRV targets left
cb(nil);
return;
end
- local next_target = table.remove(self.targets, 1);
- self.resolver = basic.new(unpack(next_target, 1, 4));
+ -- Create a new basic resolver for this SRV target
+ self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
end
+ -- Look up the next (basic) target from the current target's resolver
self.resolver:next(function (...)
if self.resolver then
self.last_error = self.resolver.last_error;
@@ -31,6 +86,9 @@ function methods:next(cb)
end
end);
return;
+ elseif self.in_progress then
+ cb(nil);
+ return;
end
if not self.hostname then
@@ -39,9 +97,9 @@ function methods:next(cb)
return;
end
- local targets = {};
+ self.in_progress = true;
+
local function ready()
- self.targets = targets;
self:next(cb);
end
@@ -63,7 +121,7 @@ function methods:next(cb)
if #answer == 0 then
if self.extra and self.extra.default_port then
- table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra });
+ self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
else
self.last_error = "zero SRV records found";
end
@@ -77,10 +135,7 @@ function methods:next(cb)
return;
end
- table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end);
- for _, record in ipairs(answer) do
- table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
- end
+ self._get_next_target = new_target_selector(answer);
else
self.last_error = err;
end
diff --git a/plugins/mod_admin_shell.lua b/plugins/mod_admin_shell.lua
index 35124e79..9af77676 100644
--- a/plugins/mod_admin_shell.lua
+++ b/plugins/mod_admin_shell.lua
@@ -36,6 +36,7 @@ local serialization = require "util.serialization";
local serialize_config = serialization.new ({ fatal = false, unquoted = true});
local time = require "util.time";
local promise = require "util.promise";
+local logger = require "util.logger";
local t_insert = table.insert;
local t_concat = table.concat;
@@ -83,8 +84,8 @@ function runner_callbacks:error(err)
self.data.print("Error: "..tostring(err));
end
-local function send_repl_output(session, line)
- return session.send(st.stanza("repl-output"):text(tostring(line)));
+local function send_repl_output(session, line, attr)
+ return session.send(st.stanza("repl-output", attr):text(tostring(line)));
end
function console:new_session(admin_session)
@@ -99,8 +100,14 @@ function console:new_session(admin_session)
end
return send_repl_output(admin_session, table.concat(t, "\t"));
end;
+ write = function (t)
+ return send_repl_output(admin_session, t, { eol = "0" });
+ end;
serialize = tostring;
disconnect = function () admin_session:close(); end;
+ is_connected = function ()
+ return not not admin_session.conn;
+ end
};
session.env = setmetatable({}, default_env_mt);
@@ -1583,6 +1590,26 @@ function def_env.http:list(hosts)
return true;
end
+def_env.watch = {};
+
+function def_env.watch:log()
+ local writing = false;
+ local sink = logger.add_simple_sink(function (source, level, message)
+ if writing then return; end
+ writing = true;
+ self.session.print(source, level, message);
+ writing = false;
+ end);
+
+ while self.session.is_connected() do
+ async.sleep(3);
+ end
+ if not logger.remove_sink(sink) then
+ module:log("warn", "Unable to remove watch:log() sink");
+ end
+end
+
+
def_env.debug = {};
function def_env.debug:logevents(host)
diff --git a/spec/net_resolvers_service_spec.lua b/spec/net_resolvers_service_spec.lua
new file mode 100644
index 00000000..53ce4754
--- /dev/null
+++ b/spec/net_resolvers_service_spec.lua
@@ -0,0 +1,241 @@
+local set = require "util.set";
+
+insulate("net.resolvers.service", function ()
+ local adns = {
+ resolver = function ()
+ return {
+ lookup = function (_, cb, qname, qtype, qclass)
+ if qname == "_xmpp-server._tcp.example.com"
+ and (qtype or "SRV") == "SRV"
+ and (qclass or "IN") == "IN" then
+ cb({
+ { -- 60+35+60
+ srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 };
+ };
+ {
+ srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 };
+ };
+ {
+ srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 };
+ };
+ {
+ srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 };
+ };
+
+ {
+ srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 };
+ };
+ {
+ srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 };
+ };
+
+ {
+ srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 };
+ };
+ });
+ elseif qname == "_xmpp-server._tcp.single.example.com"
+ and (qtype or "SRV") == "SRV"
+ and (qclass or "IN") == "IN" then
+ cb({
+ {
+ srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
+ };
+ });
+ elseif qname == "_xmpp-server._tcp.half.example.com"
+ and (qtype or "SRV") == "SRV"
+ and (qclass or "IN") == "IN" then
+ cb({
+ {
+ srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
+ };
+ {
+ srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 };
+ };
+ });
+ elseif qtype == "A" then
+ local l = qname:match("%-(%a)%.example.com$") or "1";
+ local d = ("%d"):format(l:byte())
+ cb({
+ {
+ a = "127.0.0."..d;
+ };
+ });
+ elseif qtype == "AAAA" then
+ local l = qname:match("%-(%a)%.example.com$") or "1";
+ local d = ("%04d"):format(l:byte())
+ cb({
+ {
+ aaaa = "fdeb:9619:649e:c7d9::"..d;
+ };
+ });
+ else
+ cb(nil);
+ end
+ end;
+ };
+ end;
+ };
+ package.loaded["net.adns"] = mock(adns);
+ local resolver = require "net.resolvers.service";
+ math.randomseed(os.time());
+ it("works for 99% of deployments", function ()
+ -- Most deployments only have a single SRV record, let's make
+ -- sure that works okay
+
+ local expected_targets = set.new({
+ -- xmpp0-a
+ "tcp4 127.0.0.97 5269";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5269";
+ });
+ local received_targets = set.new({});
+
+ local r = resolver.new("single.example.com", "xmpp-server");
+ local done = false;
+ local function handle_target(...)
+ if ... == nil then
+ done = true;
+ -- No more targets
+ return;
+ end
+ received_targets:add(table.concat({ ... }, " ", 1, 3));
+ end
+ r:next(handle_target);
+ while not done do
+ r:next(handle_target);
+ end
+
+ -- We should have received all expected targets, and no unexpected
+ -- ones:
+ assert.truthy(set.xor(received_targets, expected_targets):empty());
+ end);
+
+ it("supports A/AAAA fallback", function ()
+ -- Many deployments don't have any SRV records, so we should
+ -- fall back to A/AAAA records instead when that is the case
+
+ local expected_targets = set.new({
+ -- xmpp0-a
+ "tcp4 127.0.0.97 5269";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5269";
+ });
+ local received_targets = set.new({});
+
+ local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 });
+ local done = false;
+ local function handle_target(...)
+ if ... == nil then
+ done = true;
+ -- No more targets
+ return;
+ end
+ received_targets:add(table.concat({ ... }, " ", 1, 3));
+ end
+ r:next(handle_target);
+ while not done do
+ r:next(handle_target);
+ end
+
+ -- We should have received all expected targets, and no unexpected
+ -- ones:
+ assert.truthy(set.xor(received_targets, expected_targets):empty());
+ end);
+
+
+ it("works", function ()
+ local expected_targets = set.new({
+ -- xmpp0-a
+ "tcp4 127.0.0.97 5228";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5228";
+ "tcp4 127.0.0.97 5273";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5273";
+
+ -- xmpp0-b
+ "tcp4 127.0.0.98 5274";
+ "tcp6 fdeb:9619:649e:c7d9::0098 5274";
+ "tcp4 127.0.0.98 5216";
+ "tcp6 fdeb:9619:649e:c7d9::0098 5216";
+
+ -- xmpp0-c
+ "tcp4 127.0.0.99 5200";
+ "tcp6 fdeb:9619:649e:c7d9::0099 5200";
+
+ -- xmpp0-d
+ "tcp4 127.0.0.100 5256";
+ "tcp6 fdeb:9619:649e:c7d9::0100 5256";
+
+ -- xmpp2
+ "tcp4 127.0.0.49 5275";
+ "tcp6 fdeb:9619:649e:c7d9::0049 5275";
+
+ });
+ local received_targets = set.new({});
+
+ local r = resolver.new("example.com", "xmpp-server");
+ local done = false;
+ local function handle_target(...)
+ if ... == nil then
+ done = true;
+ -- No more targets
+ return;
+ end
+ received_targets:add(table.concat({ ... }, " ", 1, 3));
+ end
+ r:next(handle_target);
+ while not done do
+ r:next(handle_target);
+ end
+
+ -- We should have received all expected targets, and no unexpected
+ -- ones:
+ assert.truthy(set.xor(received_targets, expected_targets):empty());
+ end);
+
+ it("balances across weights correctly #slow", function ()
+ -- This mimics many repeated connections to 'example.com' (mock
+ -- records defined above), and records the port number of the
+ -- first target. Therefore it (should) only return priority
+ -- 0 records, and the input data is constructed such that the
+ -- last two digits of the port number represent the percentage
+ -- of times that record should (on average) be picked first.
+
+ -- To prevent random test failures, we test across a handful
+ -- of fixed (randomly selected) seeds.
+ for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do
+ math.randomseed(seed);
+
+ local results = {};
+ local function run()
+ local run_results = {};
+ local r = resolver.new("example.com", "xmpp-server");
+ local function record_target(...)
+ if ... == nil then
+ -- No more targets
+ return;
+ end
+ run_results = { ... };
+ end
+ r:next(record_target);
+ return run_results[3];
+ end
+
+ for _ = 1, 1000 do
+ local port = run();
+ results[port] = (results[port] or 0) + 1;
+ end
+
+ local ports = {};
+ for port in pairs(results) do
+ table.insert(ports, port);
+ end
+ table.sort(ports);
+ for _, port in ipairs(ports) do
+ --print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)");
+ local hit_pct = (results[port]/1000) * 100;
+ local expected_pct = port - 5200;
+ --print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct));
+ assert.is_true(math.abs(hit_pct - expected_pct) < 5);
+ end
+ --print("---");
+ end
+ end);
+end);
diff --git a/spec/util_poll_spec.lua b/spec/util_poll_spec.lua
index a763be90..05318453 100644
--- a/spec/util_poll_spec.lua
+++ b/spec/util_poll_spec.lua
@@ -1,6 +1,35 @@
-describe("util.poll", function ()
- it("loads", function ()
- require "util.poll"
+describe("util.poll", function()
+ local poll;
+ setup(function()
+ poll = require "util.poll";
end);
+ it("loads", function()
+ assert.is_table(poll);
+ assert.is_function(poll.new);
+ assert.is_string(poll.api);
+ end);
+ describe("new", function()
+ local p;
+ setup(function()
+ p = poll.new();
+ end)
+ it("times out", function ()
+ local fd, err = p:wait(0);
+ assert.falsy(fd);
+ assert.equal("timeout", err);
+ end);
+ it("works", function()
+ -- stdout should be writable, right?
+ assert.truthy(p:add(1, false, true));
+ local fd, r, w = p:wait(1);
+ assert.is_number(fd);
+ assert.is_boolean(r);
+ assert.is_boolean(w);
+ assert.equal(1, fd);
+ assert.falsy(r);
+ assert.truthy(w);
+ assert.truthy(p:del(1));
+ end);
+ end)
end);
diff --git a/spec/util_table_spec.lua b/spec/util_table_spec.lua
index 76f54b69..a0535c08 100644
--- a/spec/util_table_spec.lua
+++ b/spec/util_table_spec.lua
@@ -12,6 +12,17 @@ describe("util.table", function ()
assert.same({ "lorem", "ipsum", "dolor", "sit", "amet", n = 5 }, u_table.pack("lorem", "ipsum", "dolor", "sit", "amet"));
end);
end);
+
+ describe("move()", function ()
+ it("works", function ()
+ local t1 = { "apple", "banana", "carrot" };
+ local t2 = { "cat", "donkey", "elephant" };
+ local t3 = {};
+ u_table.move(t1, 1, 3, 1, t3);
+ u_table.move(t2, 1, 3, 3, t3);
+ assert.same({ "apple", "banana", "cat", "donkey", "elephant" }, t3);
+ end);
+ end);
end);
diff --git a/util-src/table.c b/util-src/table.c
index 9a9553fc..4bbceedb 100644
--- a/util-src/table.c
+++ b/util-src/table.c
@@ -1,11 +1,21 @@
#include <lua.h>
#include <lauxlib.h>
+#ifndef LUA_MAXINTEGER
+#include <stdint.h>
+#define LUA_MAXINTEGER PTRDIFF_MAX
+#endif
+
+#if (LUA_VERSION_NUM > 501)
+#define lua_equal(L, A, B) lua_compare(L, A, B, LUA_OPEQ)
+#endif
+
static int Lcreate_table(lua_State *L) {
lua_createtable(L, luaL_checkinteger(L, 1), luaL_checkinteger(L, 2));
return 1;
}
+/* COMPAT: w/ Lua pre-5.4 */
static int Lpack(lua_State *L) {
unsigned int n_args = lua_gettop(L);
lua_createtable(L, n_args, 1);
@@ -20,6 +30,40 @@ static int Lpack(lua_State *L) {
return 1;
}
+/* COMPAT: w/ Lua pre-5.4 */
+static int Lmove (lua_State *L) {
+ lua_Integer f = luaL_checkinteger(L, 2);
+ lua_Integer e = luaL_checkinteger(L, 3);
+ lua_Integer t = luaL_checkinteger(L, 4);
+
+ int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
+ luaL_checktype(L, 1, LUA_TTABLE);
+ luaL_checktype(L, tt, LUA_TTABLE);
+
+ if (e >= f) { /* otherwise, nothing to move */
+ lua_Integer n, i;
+ luaL_argcheck(L, f > 0 || e < LUA_MAXINTEGER + f, 3,
+ "too many elements to move");
+ n = e - f + 1; /* number of elements to move */
+ luaL_argcheck(L, t <= LUA_MAXINTEGER - n + 1, 4,
+ "destination wrap around");
+ if (t > e || t <= f || (tt != 1 && !lua_equal(L, 1, tt))) {
+ for (i = 0; i < n; i++) {
+ lua_rawgeti(L, 1, f + i);
+ lua_rawseti(L, tt, t + i);
+ }
+ } else {
+ for (i = n - 1; i >= 0; i--) {
+ lua_rawgeti(L, 1, f + i);
+ lua_rawseti(L, tt, t + i);
+ }
+ }
+ }
+
+ lua_pushvalue(L, tt); /* return destination table */
+ return 1;
+}
+
int luaopen_util_table(lua_State *L) {
#if (LUA_VERSION_NUM > 501)
luaL_checkversion(L);
@@ -29,5 +73,7 @@ int luaopen_util_table(lua_State *L) {
lua_setfield(L, -2, "create");
lua_pushcfunction(L, Lpack);
lua_setfield(L, -2, "pack");
+ lua_pushcfunction(L, Lmove);
+ lua_setfield(L, -2, "move");
return 1;
}
diff --git a/util/array.lua b/util/array.lua
index c33a5ef1..9d438940 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -8,6 +8,7 @@
local t_insert, t_sort, t_remove, t_concat
= table.insert, table.sort, table.remove, table.concat;
+local t_move = require "util.table".move;
local setmetatable = setmetatable;
local getmetatable = getmetatable;
@@ -137,13 +138,11 @@ function array_base.slice(outa, ina, i, j)
return outa;
end
- for idx = 1, 1+j-i do
- outa[idx] = ina[i+(idx-1)];
- end
+
+ t_move(ina, i, j, 1, outa);
if ina == outa then
- for idx = 2+j-i, #outa do
- outa[idx] = nil;
- end
+ -- Clear (nil) remainder of range
+ t_move(ina, #outa+1, #outa*2, 2+j-i, ina);
end
return outa;
end
@@ -209,10 +208,7 @@ function array_methods:shuffle()
end
function array_methods:append(ina)
- local len, len2 = #self, #ina;
- for i = 1, len2 do
- self[len+i] = ina[i];
- end
+ t_move(ina, 1, #ina, #self+1, self);
return self;
end
diff --git a/util/logger.lua b/util/logger.lua
index 20a5cef2..148b98dc 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -10,6 +10,7 @@
local pairs = pairs;
local ipairs = ipairs;
local require = require;
+local t_remove = table.remove;
local _ENV = nil;
-- luacheck: std none
@@ -78,6 +79,20 @@ local function add_simple_sink(simple_sink_function, levels)
for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do
add_level_sink(level, sink_function);
end
+ return sink_function;
+end
+
+local function remove_sink(sink_function)
+ local removed;
+ for level, sinks in pairs(level_sinks) do
+ for i = #sinks, 1, -1 do
+ if sinks[i] == sink_function then
+ t_remove(sinks, i);
+ removed = true;
+ end
+ end
+ end
+ return removed;
end
return {
@@ -87,4 +102,5 @@ return {
add_level_sink = add_level_sink;
add_simple_sink = add_simple_sink;
new = make_logger;
+ remove_sink = remove_sink;
};
diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua
index bce27b94..0b1dd3f9 100644
--- a/util/prosodyctl/shell.lua
+++ b/util/prosodyctl/shell.lua
@@ -89,11 +89,15 @@ local function start(arg) --luacheck: ignore 212/arg
local errors = 0; -- TODO This is weird, but works for now.
client.events.add_handler("received", function(stanza)
if stanza.name == "repl-output" or stanza.name == "repl-result" then
+ local dest = io.stdout;
if stanza.attr.type == "error" then
errors = errors + 1;
- io.stderr:write(stanza:get_text(), "\n");
+ dest = io.stderr;
+ end
+ if stanza.attr.eol == "0" then
+ dest:write(stanza:get_text());
else
- print(stanza:get_text());
+ dest:write(stanza:get_text(), "\n");
end
end
if stanza.name == "repl-result" then
diff --git a/util/stanza.lua b/util/stanza.lua
index a38f80b3..a14be5f0 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -21,6 +21,8 @@ local type = type;
local s_gsub = string.gsub;
local s_sub = string.sub;
local s_find = string.find;
+local t_move = table.move or require "util.table".move;
+local t_create = require"util.table".create;
local valid_utf8 = require "util.encodings".utf8.valid;
@@ -275,25 +277,33 @@ function stanza_mt:find(path)
end
local function _clone(stanza, only_top)
- local attr, tags = {}, {};
+ local attr = {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
local old_namespaces, namespaces = stanza.namespaces;
if old_namespaces then
namespaces = {};
for k,v in pairs(old_namespaces) do namespaces[k] = v; end
end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ local tags, new;
+ if only_top then
+ tags = {};
+ new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ else
+ tags = t_create(#stanza.tags, 0);
+ new = t_create(#stanza, 4);
+ new.name = stanza.name;
+ new.attr = attr;
+ new.namespaces = namespaces;
+ new.tags = tags;
+ end
+
+ setmetatable(new, stanza_mt);
if not only_top then
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
+ t_move(stanza, 1, #stanza, 1, new);
+ t_move(stanza.tags, 1, #stanza.tags, 1, tags);
+ new:maptags(_clone);
end
- return setmetatable(new, stanza_mt);
+ return new;
end
local function clone(stanza, only_top)