diff options
-rw-r--r-- | net/connect.lua | 26 | ||||
-rw-r--r-- | net/resolvers/basic.lua | 160 | ||||
-rw-r--r-- | net/resolvers/service.lua | 81 | ||||
-rw-r--r-- | plugins/mod_admin_shell.lua | 31 | ||||
-rw-r--r-- | spec/net_resolvers_service_spec.lua | 241 | ||||
-rw-r--r-- | spec/util_poll_spec.lua | 35 | ||||
-rw-r--r-- | spec/util_table_spec.lua | 11 | ||||
-rw-r--r-- | util-src/table.c | 46 | ||||
-rw-r--r-- | util/array.lua | 16 | ||||
-rw-r--r-- | util/logger.lua | 16 | ||||
-rw-r--r-- | util/prosodyctl/shell.lua | 8 | ||||
-rw-r--r-- | util/stanza.lua | 32 |
12 files changed, 576 insertions, 127 deletions
diff --git a/net/connect.lua b/net/connect.lua index 4b602be4..241cc65b 100644 --- a/net/connect.lua +++ b/net/connect.lua @@ -1,6 +1,7 @@ local server = require "net.server"; local log = require "util.logger".init("net.connect"); local new_id = require "util.id".short; +local timer = require "util.timer"; -- TODO #1246 Happy Eyeballs -- FIXME RFC 6724 @@ -28,11 +29,7 @@ local pending_connection_listeners = {}; local function attempt_connection(p) p:log("debug", "Checking for targets..."); - if p.conn then - pending_connections_map[p.conn] = nil; - p.conn = nil; - end - p.target_resolver:next(function (conn_type, ip, port, extra) + p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available) if not conn_type then -- No more targets to try p:log("debug", "No more connection targets to try", p.target_resolver.last_error); @@ -49,8 +46,16 @@ local function attempt_connection(p) p.last_error = err or "unknown reason"; return attempt_connection(p); end - p.conn = conn; + p.conns[conn] = true; pending_connections_map[conn] = p; + if more_targets_available then + timer.add_task(0.250, function () + if not p.connected then + p:log("debug", "Still not connected, making parallel connection attempt..."); + attempt_connection(p); + end + end); + end end); end @@ -62,6 +67,13 @@ function pending_connection_listeners.onconnect(conn) return; end pending_connections_map[conn] = nil; + if p.connected then + -- We already succeeded in connecting + p.conns[conn] = nil; + conn:close(); + return; + end + p.connected = true; p:log("debug", "Successfully connected"); conn:setlistener(p.listeners, p.data); return p.listeners.onconnect(conn); @@ -73,6 +85,7 @@ function pending_connection_listeners.ondisconnect(conn, reason) log("warn", "Failed connection, but unexpected!"); return; end + p.conns[conn] = nil; p.last_error = reason or "unknown reason"; p:log("debug", "Connection attempt failed: %s", p.last_error); attempt_connection(p); @@ -85,6 +98,7 @@ local function connect(target_resolver, listeners, options, data) listeners = assert(listeners); options = options or {}; data = data; + conns = {}; }, pending_connection_mt); p:log("debug", "Starting connection process"); diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 305bce76..15338ff4 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -2,13 +2,59 @@ local adns = require "net.adns"; local inet_pton = require "util.net".pton; local inet_ntop = require "util.net".ntop; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local promise = require "util.promise"; +local t_move = require "util.table".move; local methods = {}; local resolver_mt = { __index = methods }; -- FIXME RFC 6724 +local function do_dns_lookup(self, dns_resolver, record_type, name) + return promise.new(function (resolve, reject) + local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil; + if ipv and self.extra["use_ipv"..ipv] == false then + return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type)); + elseif record_type == "TLSA" and self.extra.use_dane ~= true then + return reject("DANE disabled - TLSA lookup skipped"); + end + dns_resolver:lookup(function (answer, err) + if not answer then + return reject(err); + elseif answer.bogus then + return reject(("Validation error in %s lookup"):format(record_type)); + elseif answer.status and #answer == 0 then + return reject(("%s in %s lookup"):format(answer.status, record_type)); + end + + local targets = { secure = answer.secure }; + for _, record in ipairs(answer) do + if ipv then + table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra }); + else + table.insert(targets, record[record_type:lower()]); + end + end + return resolve(targets); + end, name, record_type, "IN"); + end); +end + +local function merge_targets(ipv4_targets, ipv6_targets) + local result = { secure = ipv4_targets.secure and ipv6_targets.secure }; + local common_length = math.min(#ipv4_targets, #ipv6_targets); + for i = 1, common_length do + table.insert(result, ipv6_targets[i]); + table.insert(result, ipv4_targets[i]); + end + if common_length < #ipv4_targets then + t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result); + elseif common_length < #ipv6_targets then + t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result); + end + return result; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) @@ -18,7 +64,7 @@ function methods:next(cb) return; end local next_target = table.remove(self.targets, 1); - cb(unpack(next_target, 1, 4)); + cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]); return; end @@ -28,91 +74,45 @@ function methods:next(cb) return; end - local secure = true; - local tlsa = {}; - local targets = {}; - local n = 3; - local function ready() - n = n - 1; - if n > 0 then return; end - self.targets = targets; + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + + local dns_lookups = { + ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname); + ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname); + tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname)); + }; + + promise.all_settled(dns_lookups):next(function (dns_results) + -- Combine targets, assign to self.targets, self:next(cb) + local have_ipv4 = dns_results.ipv4.status == "fulfilled"; + local have_ipv6 = dns_results.ipv6.status == "fulfilled"; + + if have_ipv4 and have_ipv6 then + self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value); + elseif have_ipv4 then + self.targets = dns_results.ipv4.value; + elseif have_ipv6 then + self.targets = dns_results.ipv6.value; + else + self.targets = {}; + end + if self.extra and self.extra.use_dane then - if secure and tlsa[1] then - self.extra.tlsa = tlsa; + if self.targets.secure and dns_results.tlsa.status == "fulfilled" then + self.extra.tlsa = dns_results.tlsa.value; self.extra.dane_hostname = self.hostname; else self.extra.tlsa = nil; self.extra.dane_hostname = nil; end end - self:next(cb); - end - -- Resolve DNS to target list - local dns_resolver = adns.resolver(); - - if not self.extra or self.extra.use_ipv4 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in A lookup"; - elseif answer.status then - self.last_error = answer.status .. " in A lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "A", "IN"); - else - ready(); - end - - if not self.extra or self.extra.use_ipv6 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in AAAA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in AAAA lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "AAAA", "IN"); - else - ready(); - end - - if self.extra and self.extra.use_dane == true then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(tlsa, record.tlsa); - end - if answer.bogus then - self.last_error = "Validation error in TLSA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in TLSA lookup"; - end - else - self.last_error = err; - end - ready(); - end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN"); - else - ready(); - end + self:next(cb); + end):catch(function (err) + self.last_error = err; + self.targets = {}; + end); end local function new(hostname, port, conn_type, extra) @@ -137,7 +137,7 @@ local function new(hostname, port, conn_type, extra) hostname = ascii_host; port = port; conn_type = conn_type; - extra = extra; + extra = extra or {}; targets = targets; }, resolver_mt); end diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index 3810cac8..a7ce76a3 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -2,23 +2,78 @@ local adns = require "net.adns"; local basic = require "net.resolvers.basic"; local inet_pton = require "util.net".pton; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; +local function new_target_selector(rrset) + local rr_count = rrset and #rrset; + if not rr_count or rr_count == 0 then + rrset = nil; + else + table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end); + end + local rrset_pos = 1; + local priority_bucket, bucket_total_weight, bucket_len, bucket_used; + return function () + if not rrset then return; end + + if not priority_bucket or bucket_used >= bucket_len then + if rrset_pos > rr_count then return; end -- Used up all records + + -- Going to start on a new priority now. Gather up all the next + -- records with the same priority and add them to priority_bucket + priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0; + local current_priority; + repeat + local curr_record = rrset[rrset_pos].srv; + if not current_priority then + current_priority = curr_record.priority; + elseif current_priority ~= curr_record.priority then + break; + end + table.insert(priority_bucket, curr_record); + bucket_total_weight = bucket_total_weight + curr_record.weight; + bucket_len = bucket_len + 1; + rrset_pos = rrset_pos + 1; + until rrset_pos > rr_count; + end + + bucket_used = bucket_used + 1; + local n, running_total = math.random(0, bucket_total_weight), 0; + local target_record; + for i = 1, bucket_len do + local candidate = priority_bucket[i]; + if candidate then + running_total = running_total + candidate.weight; + if running_total >= n then + target_record = candidate; + bucket_total_weight = bucket_total_weight - candidate.weight; + priority_bucket[i] = nil; + break; + end + end + end + return target_record; + end; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) - if self.targets then - if not self.resolver then - if #self.targets == 0 then + if self.resolver or self._get_next_target then + if not self.resolver then -- Do we have a basic resolver currently? + -- We don't, so fetch a new SRV target, create a new basic resolver for it + local next_srv_target = self._get_next_target and self._get_next_target(); + if not next_srv_target then + -- No more SRV targets left cb(nil); return; end - local next_target = table.remove(self.targets, 1); - self.resolver = basic.new(unpack(next_target, 1, 4)); + -- Create a new basic resolver for this SRV target + self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra); end + -- Look up the next (basic) target from the current target's resolver self.resolver:next(function (...) if self.resolver then self.last_error = self.resolver.last_error; @@ -31,6 +86,9 @@ function methods:next(cb) end end); return; + elseif self.in_progress then + cb(nil); + return; end if not self.hostname then @@ -39,9 +97,9 @@ function methods:next(cb) return; end - local targets = {}; + self.in_progress = true; + local function ready() - self.targets = targets; self:next(cb); end @@ -63,7 +121,7 @@ function methods:next(cb) if #answer == 0 then if self.extra and self.extra.default_port then - table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); + self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra); else self.last_error = "zero SRV records found"; end @@ -77,10 +135,7 @@ function methods:next(cb) return; end - table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); - for _, record in ipairs(answer) do - table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra }); - end + self._get_next_target = new_target_selector(answer); else self.last_error = err; end diff --git a/plugins/mod_admin_shell.lua b/plugins/mod_admin_shell.lua index 35124e79..9af77676 100644 --- a/plugins/mod_admin_shell.lua +++ b/plugins/mod_admin_shell.lua @@ -36,6 +36,7 @@ local serialization = require "util.serialization"; local serialize_config = serialization.new ({ fatal = false, unquoted = true}); local time = require "util.time"; local promise = require "util.promise"; +local logger = require "util.logger"; local t_insert = table.insert; local t_concat = table.concat; @@ -83,8 +84,8 @@ function runner_callbacks:error(err) self.data.print("Error: "..tostring(err)); end -local function send_repl_output(session, line) - return session.send(st.stanza("repl-output"):text(tostring(line))); +local function send_repl_output(session, line, attr) + return session.send(st.stanza("repl-output", attr):text(tostring(line))); end function console:new_session(admin_session) @@ -99,8 +100,14 @@ function console:new_session(admin_session) end return send_repl_output(admin_session, table.concat(t, "\t")); end; + write = function (t) + return send_repl_output(admin_session, t, { eol = "0" }); + end; serialize = tostring; disconnect = function () admin_session:close(); end; + is_connected = function () + return not not admin_session.conn; + end }; session.env = setmetatable({}, default_env_mt); @@ -1583,6 +1590,26 @@ function def_env.http:list(hosts) return true; end +def_env.watch = {}; + +function def_env.watch:log() + local writing = false; + local sink = logger.add_simple_sink(function (source, level, message) + if writing then return; end + writing = true; + self.session.print(source, level, message); + writing = false; + end); + + while self.session.is_connected() do + async.sleep(3); + end + if not logger.remove_sink(sink) then + module:log("warn", "Unable to remove watch:log() sink"); + end +end + + def_env.debug = {}; function def_env.debug:logevents(host) diff --git a/spec/net_resolvers_service_spec.lua b/spec/net_resolvers_service_spec.lua new file mode 100644 index 00000000..53ce4754 --- /dev/null +++ b/spec/net_resolvers_service_spec.lua @@ -0,0 +1,241 @@ +local set = require "util.set"; + +insulate("net.resolvers.service", function () + local adns = { + resolver = function () + return { + lookup = function (_, cb, qname, qtype, qclass) + if qname == "_xmpp-server._tcp.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { -- 60+35+60 + srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 }; + }; + { + srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 }; + }; + { + srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 }; + }; + { + srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 }; + }; + + { + srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 }; + }; + { + srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 }; + }; + + { + srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 }; + }; + }); + elseif qname == "_xmpp-server._tcp.single.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { + srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 }; + }; + }); + elseif qname == "_xmpp-server._tcp.half.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { + srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 }; + }; + { + srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 }; + }; + }); + elseif qtype == "A" then + local l = qname:match("%-(%a)%.example.com$") or "1"; + local d = ("%d"):format(l:byte()) + cb({ + { + a = "127.0.0."..d; + }; + }); + elseif qtype == "AAAA" then + local l = qname:match("%-(%a)%.example.com$") or "1"; + local d = ("%04d"):format(l:byte()) + cb({ + { + aaaa = "fdeb:9619:649e:c7d9::"..d; + }; + }); + else + cb(nil); + end + end; + }; + end; + }; + package.loaded["net.adns"] = mock(adns); + local resolver = require "net.resolvers.service"; + math.randomseed(os.time()); + it("works for 99% of deployments", function () + -- Most deployments only have a single SRV record, let's make + -- sure that works okay + + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5269"; + "tcp6 fdeb:9619:649e:c7d9::0097 5269"; + }); + local received_targets = set.new({}); + + local r = resolver.new("single.example.com", "xmpp-server"); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + it("supports A/AAAA fallback", function () + -- Many deployments don't have any SRV records, so we should + -- fall back to A/AAAA records instead when that is the case + + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5269"; + "tcp6 fdeb:9619:649e:c7d9::0097 5269"; + }); + local received_targets = set.new({}); + + local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 }); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + + it("works", function () + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5228"; + "tcp6 fdeb:9619:649e:c7d9::0097 5228"; + "tcp4 127.0.0.97 5273"; + "tcp6 fdeb:9619:649e:c7d9::0097 5273"; + + -- xmpp0-b + "tcp4 127.0.0.98 5274"; + "tcp6 fdeb:9619:649e:c7d9::0098 5274"; + "tcp4 127.0.0.98 5216"; + "tcp6 fdeb:9619:649e:c7d9::0098 5216"; + + -- xmpp0-c + "tcp4 127.0.0.99 5200"; + "tcp6 fdeb:9619:649e:c7d9::0099 5200"; + + -- xmpp0-d + "tcp4 127.0.0.100 5256"; + "tcp6 fdeb:9619:649e:c7d9::0100 5256"; + + -- xmpp2 + "tcp4 127.0.0.49 5275"; + "tcp6 fdeb:9619:649e:c7d9::0049 5275"; + + }); + local received_targets = set.new({}); + + local r = resolver.new("example.com", "xmpp-server"); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + it("balances across weights correctly #slow", function () + -- This mimics many repeated connections to 'example.com' (mock + -- records defined above), and records the port number of the + -- first target. Therefore it (should) only return priority + -- 0 records, and the input data is constructed such that the + -- last two digits of the port number represent the percentage + -- of times that record should (on average) be picked first. + + -- To prevent random test failures, we test across a handful + -- of fixed (randomly selected) seeds. + for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do + math.randomseed(seed); + + local results = {}; + local function run() + local run_results = {}; + local r = resolver.new("example.com", "xmpp-server"); + local function record_target(...) + if ... == nil then + -- No more targets + return; + end + run_results = { ... }; + end + r:next(record_target); + return run_results[3]; + end + + for _ = 1, 1000 do + local port = run(); + results[port] = (results[port] or 0) + 1; + end + + local ports = {}; + for port in pairs(results) do + table.insert(ports, port); + end + table.sort(ports); + for _, port in ipairs(ports) do + --print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)"); + local hit_pct = (results[port]/1000) * 100; + local expected_pct = port - 5200; + --print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct)); + assert.is_true(math.abs(hit_pct - expected_pct) < 5); + end + --print("---"); + end + end); +end); diff --git a/spec/util_poll_spec.lua b/spec/util_poll_spec.lua index a763be90..05318453 100644 --- a/spec/util_poll_spec.lua +++ b/spec/util_poll_spec.lua @@ -1,6 +1,35 @@ -describe("util.poll", function () - it("loads", function () - require "util.poll" +describe("util.poll", function() + local poll; + setup(function() + poll = require "util.poll"; end); + it("loads", function() + assert.is_table(poll); + assert.is_function(poll.new); + assert.is_string(poll.api); + end); + describe("new", function() + local p; + setup(function() + p = poll.new(); + end) + it("times out", function () + local fd, err = p:wait(0); + assert.falsy(fd); + assert.equal("timeout", err); + end); + it("works", function() + -- stdout should be writable, right? + assert.truthy(p:add(1, false, true)); + local fd, r, w = p:wait(1); + assert.is_number(fd); + assert.is_boolean(r); + assert.is_boolean(w); + assert.equal(1, fd); + assert.falsy(r); + assert.truthy(w); + assert.truthy(p:del(1)); + end); + end) end); diff --git a/spec/util_table_spec.lua b/spec/util_table_spec.lua index 76f54b69..a0535c08 100644 --- a/spec/util_table_spec.lua +++ b/spec/util_table_spec.lua @@ -12,6 +12,17 @@ describe("util.table", function () assert.same({ "lorem", "ipsum", "dolor", "sit", "amet", n = 5 }, u_table.pack("lorem", "ipsum", "dolor", "sit", "amet")); end); end); + + describe("move()", function () + it("works", function () + local t1 = { "apple", "banana", "carrot" }; + local t2 = { "cat", "donkey", "elephant" }; + local t3 = {}; + u_table.move(t1, 1, 3, 1, t3); + u_table.move(t2, 1, 3, 3, t3); + assert.same({ "apple", "banana", "cat", "donkey", "elephant" }, t3); + end); + end); end); diff --git a/util-src/table.c b/util-src/table.c index 9a9553fc..4bbceedb 100644 --- a/util-src/table.c +++ b/util-src/table.c @@ -1,11 +1,21 @@ #include <lua.h> #include <lauxlib.h> +#ifndef LUA_MAXINTEGER +#include <stdint.h> +#define LUA_MAXINTEGER PTRDIFF_MAX +#endif + +#if (LUA_VERSION_NUM > 501) +#define lua_equal(L, A, B) lua_compare(L, A, B, LUA_OPEQ) +#endif + static int Lcreate_table(lua_State *L) { lua_createtable(L, luaL_checkinteger(L, 1), luaL_checkinteger(L, 2)); return 1; } +/* COMPAT: w/ Lua pre-5.4 */ static int Lpack(lua_State *L) { unsigned int n_args = lua_gettop(L); lua_createtable(L, n_args, 1); @@ -20,6 +30,40 @@ static int Lpack(lua_State *L) { return 1; } +/* COMPAT: w/ Lua pre-5.4 */ +static int Lmove (lua_State *L) { + lua_Integer f = luaL_checkinteger(L, 2); + lua_Integer e = luaL_checkinteger(L, 3); + lua_Integer t = luaL_checkinteger(L, 4); + + int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, tt, LUA_TTABLE); + + if (e >= f) { /* otherwise, nothing to move */ + lua_Integer n, i; + luaL_argcheck(L, f > 0 || e < LUA_MAXINTEGER + f, 3, + "too many elements to move"); + n = e - f + 1; /* number of elements to move */ + luaL_argcheck(L, t <= LUA_MAXINTEGER - n + 1, 4, + "destination wrap around"); + if (t > e || t <= f || (tt != 1 && !lua_equal(L, 1, tt))) { + for (i = 0; i < n; i++) { + lua_rawgeti(L, 1, f + i); + lua_rawseti(L, tt, t + i); + } + } else { + for (i = n - 1; i >= 0; i--) { + lua_rawgeti(L, 1, f + i); + lua_rawseti(L, tt, t + i); + } + } + } + + lua_pushvalue(L, tt); /* return destination table */ + return 1; +} + int luaopen_util_table(lua_State *L) { #if (LUA_VERSION_NUM > 501) luaL_checkversion(L); @@ -29,5 +73,7 @@ int luaopen_util_table(lua_State *L) { lua_setfield(L, -2, "create"); lua_pushcfunction(L, Lpack); lua_setfield(L, -2, "pack"); + lua_pushcfunction(L, Lmove); + lua_setfield(L, -2, "move"); return 1; } diff --git a/util/array.lua b/util/array.lua index c33a5ef1..9d438940 100644 --- a/util/array.lua +++ b/util/array.lua @@ -8,6 +8,7 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; +local t_move = require "util.table".move; local setmetatable = setmetatable; local getmetatable = getmetatable; @@ -137,13 +138,11 @@ function array_base.slice(outa, ina, i, j) return outa; end - for idx = 1, 1+j-i do - outa[idx] = ina[i+(idx-1)]; - end + + t_move(ina, i, j, 1, outa); if ina == outa then - for idx = 2+j-i, #outa do - outa[idx] = nil; - end + -- Clear (nil) remainder of range + t_move(ina, #outa+1, #outa*2, 2+j-i, ina); end return outa; end @@ -209,10 +208,7 @@ function array_methods:shuffle() end function array_methods:append(ina) - local len, len2 = #self, #ina; - for i = 1, len2 do - self[len+i] = ina[i]; - end + t_move(ina, 1, #ina, #self+1, self); return self; end diff --git a/util/logger.lua b/util/logger.lua index 20a5cef2..148b98dc 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -10,6 +10,7 @@ local pairs = pairs; local ipairs = ipairs; local require = require; +local t_remove = table.remove; local _ENV = nil; -- luacheck: std none @@ -78,6 +79,20 @@ local function add_simple_sink(simple_sink_function, levels) for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do add_level_sink(level, sink_function); end + return sink_function; +end + +local function remove_sink(sink_function) + local removed; + for level, sinks in pairs(level_sinks) do + for i = #sinks, 1, -1 do + if sinks[i] == sink_function then + t_remove(sinks, i); + removed = true; + end + end + end + return removed; end return { @@ -87,4 +102,5 @@ return { add_level_sink = add_level_sink; add_simple_sink = add_simple_sink; new = make_logger; + remove_sink = remove_sink; }; diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua index bce27b94..0b1dd3f9 100644 --- a/util/prosodyctl/shell.lua +++ b/util/prosodyctl/shell.lua @@ -89,11 +89,15 @@ local function start(arg) --luacheck: ignore 212/arg local errors = 0; -- TODO This is weird, but works for now. client.events.add_handler("received", function(stanza) if stanza.name == "repl-output" or stanza.name == "repl-result" then + local dest = io.stdout; if stanza.attr.type == "error" then errors = errors + 1; - io.stderr:write(stanza:get_text(), "\n"); + dest = io.stderr; + end + if stanza.attr.eol == "0" then + dest:write(stanza:get_text()); else - print(stanza:get_text()); + dest:write(stanza:get_text(), "\n"); end end if stanza.name == "repl-result" then diff --git a/util/stanza.lua b/util/stanza.lua index a38f80b3..a14be5f0 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -21,6 +21,8 @@ local type = type; local s_gsub = string.gsub; local s_sub = string.sub; local s_find = string.find; +local t_move = table.move or require "util.table".move; +local t_create = require"util.table".create; local valid_utf8 = require "util.encodings".utf8.valid; @@ -275,25 +277,33 @@ function stanza_mt:find(path) end local function _clone(stanza, only_top) - local attr, tags = {}, {}; + local attr = {}; for k,v in pairs(stanza.attr) do attr[k] = v; end local old_namespaces, namespaces = stanza.namespaces; if old_namespaces then namespaces = {}; for k,v in pairs(old_namespaces) do namespaces[k] = v; end end - local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + local tags, new; + if only_top then + tags = {}; + new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + else + tags = t_create(#stanza.tags, 0); + new = t_create(#stanza, 4); + new.name = stanza.name; + new.attr = attr; + new.namespaces = namespaces; + new.tags = tags; + end + + setmetatable(new, stanza_mt); if not only_top then - for i=1,#stanza do - local child = stanza[i]; - if child.name then - child = _clone(child); - t_insert(tags, child); - end - t_insert(new, child); - end + t_move(stanza, 1, #stanza, 1, new); + t_move(stanza.tags, 1, #stanza.tags, 1, tags); + new:maptags(_clone); end - return setmetatable(new, stanza_mt); + return new; end local function clone(stanza, only_top) |