diff options
37 files changed, 1255 insertions, 273 deletions
@@ -1,3 +1,17 @@ +TRUNK +===== + +## New + +### Administration + +- Add 'watch log' command to follow live debug logs at runtime (even if disabled) + +### Networking + +- Honour 'weight' parameter during SRV record selection +- Support for RFC 8305 "Happy Eyeballs" to improve IPv4/IPv6 connectivity + 0.12.0 ====== diff --git a/GNUmakefile b/GNUmakefile index e9ec78c4..c8d2d3dd 100644 --- a/GNUmakefile +++ b/GNUmakefile @@ -71,12 +71,13 @@ install-util: util/encodings.so util/encodings.so util/pposix.so util/signal.so install-plugins: $(MKDIR) $(MODULES) - $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam + $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam $(MODULES)/mod_debug_stanzas $(INSTALL_DATA) plugins/*.lua $(MODULES) $(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub $(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc $(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc $(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam + $(INSTALL_DATA) plugins/mod_debug_stanzas/*.lua $(MODULES)/mod_debug_stanzas install-man: $(MKDIR) $(MAN)/man1 diff --git a/core/certmanager.lua b/core/certmanager.lua index 7a82c786..0c71e448 100644 --- a/core/certmanager.lua +++ b/core/certmanager.lua @@ -9,9 +9,8 @@ local ssl = require "ssl"; local configmanager = require "core.configmanager"; local log = require "util.logger".init("certmanager"); -local ssl_context = ssl.context or require "ssl.context"; local ssl_newcontext = ssl.newcontext; -local new_config = require"util.sslconfig".new; +local new_config = require"net.server".tls_builder; local stat = require "lfs".attributes; local x509 = require "util.x509"; @@ -313,10 +312,6 @@ else core_defaults.curveslist = nil; end -local path_options = { -- These we pass through resolve_path() - key = true, certificate = true, cafile = true, capath = true, dhparam = true -} - local function create_context(host, mode, ...) local cfg = new_config(); cfg:apply(core_defaults); @@ -352,34 +347,7 @@ local function create_context(host, mode, ...) if user_ssl_config.certificate and not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end end - for option in pairs(path_options) do - if type(user_ssl_config[option]) == "string" then - user_ssl_config[option] = resolve_path(config_path, user_ssl_config[option]); - else - user_ssl_config[option] = nil; - end - end - - -- LuaSec expects dhparam to be a callback that takes two arguments. - -- We ignore those because it is mostly used for having a separate - -- set of params for EXPORT ciphers, which we don't have by default. - if type(user_ssl_config.dhparam) == "string" then - local f, err = io_open(user_ssl_config.dhparam); - if not f then return nil, "Could not open DH parameters: "..err end - local dhparam = f:read("*a"); - f:close(); - user_ssl_config.dhparam = function() return dhparam; end - end - - local ctx, err = ssl_newcontext(user_ssl_config); - - -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care - -- of it ourselves (W/A for #x) - if ctx and user_ssl_config.ciphers then - local success; - success, err = ssl_context.setcipher(ctx, user_ssl_config.ciphers); - if not success then ctx = nil; end - end + local ctx, err = cfg:build(); if not ctx then err = err or "invalid ssl config" diff --git a/core/portmanager.lua b/core/portmanager.lua index 38c74b66..8c7dfddb 100644 --- a/core/portmanager.lua +++ b/core/portmanager.lua @@ -240,21 +240,22 @@ local function add_sni_host(host, service) log("debug", "Gathering certificates for SNI for host %s, %s service", host, service or "default"); for name, interface, port, n, active_service --luacheck: ignore 213 in active_services:iter(service, nil, nil, nil) do - if active_service.server.hosts and active_service.tls_cfg then - local config_prefix = (active_service.config_prefix or name).."_"; - if config_prefix == "_" then config_prefix = ""; end - local prefix_ssl_config = config.get(host, config_prefix.."ssl"); + if active_service.server and active_service.tls_cfg then local alternate_host = name and config.get(host, name.."_host"); if not alternate_host and name == "https" then -- TODO should this be some generic thing? e.g. in the service definition alternate_host = config.get(host, "http_host"); end local autocert = certmanager.find_host_cert(alternate_host or host); - -- luacheck: ignore 211/cfg - local ssl, err, cfg = certmanager.create_context(host, "server", prefix_ssl_config, autocert, active_service.tls_cfg); - if ssl then - active_service.server.hosts[alternate_host or host] = ssl; - else + local manualcert = active_service.tls_cfg; + local certificate = (autocert and autocert.certificate) or manualcert.certificate; + local key = (autocert and autocert.key) or manualcert.key; + local ok, err = active_service.server:sslctx():set_sni_host( + host, + certificate, + key + ); + if not ok then log("error", "Error creating TLS context for SNI host %s: %s", host, err); end end @@ -277,7 +278,7 @@ prosody.events.add_handler("host-deactivated", function (host) for name, interface, port, n, active_service --luacheck: ignore 213 in active_services:iter(nil, nil, nil, nil) do if active_service.tls_cfg then - active_service.server.hosts[host] = nil; + active_service.server:sslctx():remove_sni_host(host) end end end); @@ -73,12 +73,13 @@ install-util: util/encodings.so util/encodings.so util/pposix.so util/signal.so install-plugins: $(MKDIR) $(MODULES) - $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam + $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam $(MODULES)/mod_debug_stanzas $(INSTALL_DATA) plugins/*.lua $(MODULES) $(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub $(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc $(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc $(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam + $(INSTALL_DATA) plugins/mod_debug_stanzas/*.lua $(MODULES)/mod_debug_stanzas install-man: $(MKDIR) $(MAN)/man1 diff --git a/net/connect.lua b/net/connect.lua index 4b602be4..d85afcff 100644 --- a/net/connect.lua +++ b/net/connect.lua @@ -1,6 +1,7 @@ local server = require "net.server"; local log = require "util.logger".init("net.connect"); local new_id = require "util.id".short; +local timer = require "util.timer"; -- TODO #1246 Happy Eyeballs -- FIXME RFC 6724 @@ -28,16 +29,17 @@ local pending_connection_listeners = {}; local function attempt_connection(p) p:log("debug", "Checking for targets..."); - if p.conn then - pending_connections_map[p.conn] = nil; - p.conn = nil; - end - p.target_resolver:next(function (conn_type, ip, port, extra) + p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available) if not conn_type then -- No more targets to try p:log("debug", "No more connection targets to try", p.target_resolver.last_error); - if p.listeners.onfail then - p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + if next(p.conns) == nil then + p:log("debug", "No more targets, no pending connections. Connection failed."); + if p.listeners.onfail then + p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service"); + end + else + p:log("debug", "One or more connection attempts are still pending. Waiting for now."); end return; end @@ -49,8 +51,16 @@ local function attempt_connection(p) p.last_error = err or "unknown reason"; return attempt_connection(p); end - p.conn = conn; + p.conns[conn] = true; pending_connections_map[conn] = p; + if more_targets_available then + timer.add_task(0.250, function () + if not p.connected then + p:log("debug", "Still not connected, making parallel connection attempt..."); + attempt_connection(p); + end + end); + end end); end @@ -62,6 +72,13 @@ function pending_connection_listeners.onconnect(conn) return; end pending_connections_map[conn] = nil; + if p.connected then + -- We already succeeded in connecting + p.conns[conn] = nil; + conn:close(); + return; + end + p.connected = true; p:log("debug", "Successfully connected"); conn:setlistener(p.listeners, p.data); return p.listeners.onconnect(conn); @@ -73,9 +90,18 @@ function pending_connection_listeners.ondisconnect(conn, reason) log("warn", "Failed connection, but unexpected!"); return; end + p.conns[conn] = nil; + pending_connections_map[conn] = nil; p.last_error = reason or "unknown reason"; p:log("debug", "Connection attempt failed: %s", p.last_error); - attempt_connection(p); + if p.connected then + p:log("debug", "Connection already established, ignoring failure"); + elseif next(p.conns) == nil then + p:log("debug", "No pending connection attempts, and not yet connected"); + attempt_connection(p); + else + p:log("debug", "Other attempts are still pending, ignoring failure"); + end end local function connect(target_resolver, listeners, options, data) @@ -85,6 +111,7 @@ local function connect(target_resolver, listeners, options, data) listeners = assert(listeners); options = options or {}; data = data; + conns = {}; }, pending_connection_mt); p:log("debug", "Starting connection process"); diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua index 305bce76..15338ff4 100644 --- a/net/resolvers/basic.lua +++ b/net/resolvers/basic.lua @@ -2,13 +2,59 @@ local adns = require "net.adns"; local inet_pton = require "util.net".pton; local inet_ntop = require "util.net".ntop; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 +local promise = require "util.promise"; +local t_move = require "util.table".move; local methods = {}; local resolver_mt = { __index = methods }; -- FIXME RFC 6724 +local function do_dns_lookup(self, dns_resolver, record_type, name) + return promise.new(function (resolve, reject) + local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil; + if ipv and self.extra["use_ipv"..ipv] == false then + return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type)); + elseif record_type == "TLSA" and self.extra.use_dane ~= true then + return reject("DANE disabled - TLSA lookup skipped"); + end + dns_resolver:lookup(function (answer, err) + if not answer then + return reject(err); + elseif answer.bogus then + return reject(("Validation error in %s lookup"):format(record_type)); + elseif answer.status and #answer == 0 then + return reject(("%s in %s lookup"):format(answer.status, record_type)); + end + + local targets = { secure = answer.secure }; + for _, record in ipairs(answer) do + if ipv then + table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra }); + else + table.insert(targets, record[record_type:lower()]); + end + end + return resolve(targets); + end, name, record_type, "IN"); + end); +end + +local function merge_targets(ipv4_targets, ipv6_targets) + local result = { secure = ipv4_targets.secure and ipv6_targets.secure }; + local common_length = math.min(#ipv4_targets, #ipv6_targets); + for i = 1, common_length do + table.insert(result, ipv6_targets[i]); + table.insert(result, ipv4_targets[i]); + end + if common_length < #ipv4_targets then + t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result); + elseif common_length < #ipv6_targets then + t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result); + end + return result; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) @@ -18,7 +64,7 @@ function methods:next(cb) return; end local next_target = table.remove(self.targets, 1); - cb(unpack(next_target, 1, 4)); + cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]); return; end @@ -28,91 +74,45 @@ function methods:next(cb) return; end - local secure = true; - local tlsa = {}; - local targets = {}; - local n = 3; - local function ready() - n = n - 1; - if n > 0 then return; end - self.targets = targets; + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + + local dns_lookups = { + ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname); + ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname); + tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname)); + }; + + promise.all_settled(dns_lookups):next(function (dns_results) + -- Combine targets, assign to self.targets, self:next(cb) + local have_ipv4 = dns_results.ipv4.status == "fulfilled"; + local have_ipv6 = dns_results.ipv6.status == "fulfilled"; + + if have_ipv4 and have_ipv6 then + self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value); + elseif have_ipv4 then + self.targets = dns_results.ipv4.value; + elseif have_ipv6 then + self.targets = dns_results.ipv6.value; + else + self.targets = {}; + end + if self.extra and self.extra.use_dane then - if secure and tlsa[1] then - self.extra.tlsa = tlsa; + if self.targets.secure and dns_results.tlsa.status == "fulfilled" then + self.extra.tlsa = dns_results.tlsa.value; self.extra.dane_hostname = self.hostname; else self.extra.tlsa = nil; self.extra.dane_hostname = nil; end end - self:next(cb); - end - -- Resolve DNS to target list - local dns_resolver = adns.resolver(); - - if not self.extra or self.extra.use_ipv4 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in A lookup"; - elseif answer.status then - self.last_error = answer.status .. " in A lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "A", "IN"); - else - ready(); - end - - if not self.extra or self.extra.use_ipv6 ~= false then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); - end - if answer.bogus then - self.last_error = "Validation error in AAAA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in AAAA lookup"; - end - else - self.last_error = err; - end - ready(); - end, self.hostname, "AAAA", "IN"); - else - ready(); - end - - if self.extra and self.extra.use_dane == true then - dns_resolver:lookup(function (answer, err) - if answer then - secure = secure and answer.secure; - for _, record in ipairs(answer) do - table.insert(tlsa, record.tlsa); - end - if answer.bogus then - self.last_error = "Validation error in TLSA lookup"; - elseif answer.status then - self.last_error = answer.status .. " in TLSA lookup"; - end - else - self.last_error = err; - end - ready(); - end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN"); - else - ready(); - end + self:next(cb); + end):catch(function (err) + self.last_error = err; + self.targets = {}; + end); end local function new(hostname, port, conn_type, extra) @@ -137,7 +137,7 @@ local function new(hostname, port, conn_type, extra) hostname = ascii_host; port = port; conn_type = conn_type; - extra = extra; + extra = extra or {}; targets = targets; }, resolver_mt); end diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua index 3810cac8..a7ce76a3 100644 --- a/net/resolvers/service.lua +++ b/net/resolvers/service.lua @@ -2,23 +2,78 @@ local adns = require "net.adns"; local basic = require "net.resolvers.basic"; local inet_pton = require "util.net".pton; local idna_to_ascii = require "util.encodings".idna.to_ascii; -local unpack = table.unpack or unpack; -- luacheck: ignore 113 local methods = {}; local resolver_mt = { __index = methods }; +local function new_target_selector(rrset) + local rr_count = rrset and #rrset; + if not rr_count or rr_count == 0 then + rrset = nil; + else + table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end); + end + local rrset_pos = 1; + local priority_bucket, bucket_total_weight, bucket_len, bucket_used; + return function () + if not rrset then return; end + + if not priority_bucket or bucket_used >= bucket_len then + if rrset_pos > rr_count then return; end -- Used up all records + + -- Going to start on a new priority now. Gather up all the next + -- records with the same priority and add them to priority_bucket + priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0; + local current_priority; + repeat + local curr_record = rrset[rrset_pos].srv; + if not current_priority then + current_priority = curr_record.priority; + elseif current_priority ~= curr_record.priority then + break; + end + table.insert(priority_bucket, curr_record); + bucket_total_weight = bucket_total_weight + curr_record.weight; + bucket_len = bucket_len + 1; + rrset_pos = rrset_pos + 1; + until rrset_pos > rr_count; + end + + bucket_used = bucket_used + 1; + local n, running_total = math.random(0, bucket_total_weight), 0; + local target_record; + for i = 1, bucket_len do + local candidate = priority_bucket[i]; + if candidate then + running_total = running_total + candidate.weight; + if running_total >= n then + target_record = candidate; + bucket_total_weight = bucket_total_weight - candidate.weight; + priority_bucket[i] = nil; + break; + end + end + end + return target_record; + end; +end + -- Find the next target to connect to, and -- pass it to cb() function methods:next(cb) - if self.targets then - if not self.resolver then - if #self.targets == 0 then + if self.resolver or self._get_next_target then + if not self.resolver then -- Do we have a basic resolver currently? + -- We don't, so fetch a new SRV target, create a new basic resolver for it + local next_srv_target = self._get_next_target and self._get_next_target(); + if not next_srv_target then + -- No more SRV targets left cb(nil); return; end - local next_target = table.remove(self.targets, 1); - self.resolver = basic.new(unpack(next_target, 1, 4)); + -- Create a new basic resolver for this SRV target + self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra); end + -- Look up the next (basic) target from the current target's resolver self.resolver:next(function (...) if self.resolver then self.last_error = self.resolver.last_error; @@ -31,6 +86,9 @@ function methods:next(cb) end end); return; + elseif self.in_progress then + cb(nil); + return; end if not self.hostname then @@ -39,9 +97,9 @@ function methods:next(cb) return; end - local targets = {}; + self.in_progress = true; + local function ready() - self.targets = targets; self:next(cb); end @@ -63,7 +121,7 @@ function methods:next(cb) if #answer == 0 then if self.extra and self.extra.default_port then - table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra }); + self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra); else self.last_error = "zero SRV records found"; end @@ -77,10 +135,7 @@ function methods:next(cb) return; end - table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end); - for _, record in ipairs(answer) do - table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra }); - end + self._get_next_target = new_target_selector(answer); else self.last_error = err; end diff --git a/net/server.lua b/net/server.lua index 0696fd52..72272bef 100644 --- a/net/server.lua +++ b/net/server.lua @@ -118,6 +118,13 @@ if prosody and set_config then prosody.events.add_handler("config-reloaded", load_config); end +local tls_builder = server.tls_builder; +-- resolving the basedir here avoids util.sslconfig depending on +-- prosody.paths.config +function server.tls_builder() + return tls_builder(prosody.paths.config or "") +end + -- require "net.server" shall now forever return this, -- ie. server_select or server_event as chosen above. return server; diff --git a/net/server_epoll.lua b/net/server_epoll.lua index fa275d71..8e75e072 100644 --- a/net/server_epoll.lua +++ b/net/server_epoll.lua @@ -18,7 +18,6 @@ local traceback = debug.traceback; local logger = require "util.logger"; local log = logger.init("server_epoll"); local socket = require "socket"; -local luasec = require "ssl"; local realtime = require "util.time".now; local monotonic = require "util.time".monotonic; local indexedbheap = require "util.indexedbheap"; @@ -28,6 +27,8 @@ local inet_pton = inet.pton; local _SOCKETINVALID = socket._SOCKETINVALID or -1; local new_id = require "util.id".short; local xpcall = require "util.xpcall".xpcall; +local sslconfig = require "util.sslconfig"; +local tls_impl = require "net.tls_luasec"; local poller = require "util.poll" local EEXIST = poller.EEXIST; @@ -614,6 +615,30 @@ function interface:set_sslctx(sslctx) self._sslctx = sslctx; end +function interface:sslctx() + return self.tls_ctx +end + +function interface:ssl_info() + local sock = self.conn; + return sock.info and sock:info(); +end + +function interface:ssl_peercertificate() + local sock = self.conn; + return sock.getpeercertificate and sock:getpeercertificate(); +end + +function interface:ssl_peerverification() + local sock = self.conn; + return sock.getpeerverification and sock:getpeerverification(); +end + +function interface:ssl_peerfinished() + local sock = self.conn; + return sock.getpeerfinished and sock:getpeerfinished(); +end + function interface:starttls(tls_ctx) if tls_ctx then self.tls_ctx = tls_ctx; end self.starttls = false; @@ -641,11 +666,7 @@ function interface:inittls(tls_ctx, now) self.starttls = false; self:debug("Starting TLS now"); self:updatenames(); -- Can't getpeer/sockname after wrap() - local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx); - if not ok then - conn, err = ok, conn; - self:debug("Failed to initialize TLS: %s", err); - end + local conn, err = self.tls_ctx:wrap(self.conn); if not conn then self:on("disconnect", err); self:destroy(); @@ -656,8 +677,8 @@ function interface:inittls(tls_ctx, now) if conn.sni then if self.servername then conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - conn:sni(self._server.hosts, true); + elseif next(self.tls_ctx._sni_contexts) ~= nil then + conn:sni(self.tls_ctx._sni_contexts, true); end end if self.extra and self.extra.tlsa and conn.settlsa then @@ -1085,6 +1106,10 @@ return { cfg = setmetatable(newconfig, default_config); end; + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + -- libevent emulation event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; addevent = function (fd, mode, callback) diff --git a/net/server_event.lua b/net/server_event.lua index c30181b8..313ba981 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -47,11 +47,13 @@ local s_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield -local has_luasec, ssl = pcall ( require , "ssl" ) +local has_luasec = pcall ( require , "ssl" ) local socket = require "socket" local levent = require "luaevent.core" local inet = require "util.net"; local inet_pton = inet.pton; +local sslconfig = require "util.sslconfig"; +local tls_impl = require "net.tls_luasec"; local socket_gettime = socket.gettime @@ -153,7 +155,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed _ = self.eventwrite and self.eventwrite:close( ) self.eventread, self.eventwrite = nil, nil local err - self.conn, err = ssl.wrap( self.conn, self._sslctx ) + self.conn, err = self._sslctx:wrap(self.conn) if err then self.fatalerror = err self.conn = nil -- cannot be used anymore @@ -168,8 +170,8 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed if self.conn.sni then if self.servername then self.conn:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - self.conn:sni(self._server.hosts, true); + elseif next(self._sslctx._sni_contexts) ~= nil then + self.conn:sni(self._sslctx._sni_contexts, true); end end @@ -274,6 +276,26 @@ function interface_mt:pause() return self:_lock(self.nointerface, true, self.nowriting); end +function interface_mt:sslctx() + return self._sslctx +end + +function interface_mt:ssl_info() + return self.conn.info and self.conn:info() +end + +function interface_mt:ssl_peercertificate() + return self.conn.getpeercertificate and self.conn:getpeercertificate() +end + +function interface_mt:ssl_peerverification() + return self.conn.getpeerverification and self.conn:getpeerverification() +end + +function interface_mt:ssl_peerfinished() + return self.conn.getpeerfinished and self.conn:getpeerfinished() +end + function interface_mt:resume() self:_lock(self.nointerface, false, self.nowriting); if self.readcallback and not self.eventread then @@ -924,6 +946,10 @@ return { add_task = add_task, watchfd = watchfd, + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, + __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, __AUTHOR = SCRIPT_AUTHOR, diff --git a/net/server_select.lua b/net/server_select.lua index eea850ce..80f5f590 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -47,15 +47,15 @@ local coroutine_yield = coroutine.yield --// extern libs //-- -local has_luasec, luasec = pcall ( require , "ssl" ) local luasocket = use "socket" or require "socket" local luasocket_gettime = luasocket.gettime local inet = require "util.net"; local inet_pton = inet.pton; +local sslconfig = require "util.sslconfig"; +local has_luasec, tls_impl = pcall(require, "net.tls_luasec"); --// extern lib methods //-- -local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind local socket_select = luasocket.select @@ -359,6 +359,18 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.sslctx = function ( ) return sslctx end + handler.ssl_info = function( ) + return socket.info and socket:info() + end + handler.ssl_peercertificate = function( ) + return socket.getpeercertificate and socket:getpeercertificate() + end + handler.ssl_peerverification = function( ) + return socket.getpeerverification and socket:getpeerverification() + end + handler.ssl_peerfinished = function( ) + return socket.getpeerfinished and socket:getpeerfinished() + end handler.send = function( _, data, i, j ) return send( socket, data, i, j ) end @@ -652,7 +664,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end out_put( "server.lua: attempting to start tls on " .. tostring( socket ) ) local oldsocket, err = socket - socket, err = ssl_wrap( socket, sslctx ) -- wrap socket + socket, err = sslctx:wrap(socket) -- wrap socket if not socket then out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") ) @@ -662,8 +674,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if socket.sni then if self.servername then socket:sni(self.servername); - elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then - socket:sni(self.server().hosts, true); + elseif next(sslctx._sni_contexts) ~= nil then + socket:sni(sslctx._sni_contexts, true); end end @@ -1169,4 +1181,8 @@ return { removeserver = removeserver, get_backend = get_backend, changesettings = changesettings, + + tls_builder = function(basedir) + return sslconfig._new(tls_impl.new_context, basedir) + end, } diff --git a/net/tls_luasec.lua b/net/tls_luasec.lua new file mode 100644 index 00000000..2bedb5ab --- /dev/null +++ b/net/tls_luasec.lua @@ -0,0 +1,89 @@ +-- Prosody IM +-- Copyright (C) 2021 Prosody folks +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +--[[ +This file provides a shim abstraction over LuaSec, consolidating some code +which was previously spread between net.server backends, portmanager and +certmanager. + +The goal is to provide a more or less well-defined API on top of LuaSec which +abstracts away some of the things which are not needed and simplifies usage of +commonly used things (such as SNI contexts). Eventually, network backends +which do not rely on LuaSocket+LuaSec should be able to provide *this* API +instead of having to mimic LuaSec. +]] +local ssl = require "ssl"; +local ssl_newcontext = ssl.newcontext; +local ssl_context = ssl.context or require "ssl.context"; +local io_open = io.open; + +local context_api = {}; +local context_mt = {__index = context_api}; + +function context_api:set_sni_host(host, cert, key) + local ctx, err = self._builder:clone():apply({ + certificate = cert, + key = key, + }):build(); + if not ctx then + return false, err + end + + self._sni_contexts[host] = ctx._inner + + return true, nil +end + +function context_api:remove_sni_host(host) + self._sni_contexts[host] = nil +end + +function context_api:wrap(sock) + local ok, conn, err = pcall(ssl.wrap, sock, self._inner); + if not ok then + return nil, err + end + return conn, nil +end + +local function new_context(cfg, builder) + -- LuaSec expects dhparam to be a callback that takes two arguments. + -- We ignore those because it is mostly used for having a separate + -- set of params for EXPORT ciphers, which we don't have by default. + if type(cfg.dhparam) == "string" then + local f, err = io_open(cfg.dhparam); + if not f then return nil, "Could not open DH parameters: "..err end + local dhparam = f:read("*a"); + f:close(); + cfg.dhparam = function() return dhparam; end + end + + local inner, err = ssl_newcontext(cfg); + if not inner then + return nil, err + end + + -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care + -- of it ourselves (W/A for #x) + if inner and cfg.ciphers then + local success; + success, err = ssl_context.setcipher(inner, cfg.ciphers); + if not success then + return nil, err + end + end + + return setmetatable({ + _inner = inner, + _builder = builder, + _sni_contexts = {}, + }, context_mt), nil +end + +return { + new_context = new_context, +}; diff --git a/plugins/adhoc/adhoc.lib.lua b/plugins/adhoc/adhoc.lib.lua index 4cf6911d..eb91f252 100644 --- a/plugins/adhoc/adhoc.lib.lua +++ b/plugins/adhoc/adhoc.lib.lua @@ -34,6 +34,8 @@ function _M.handle_cmd(command, origin, stanza) local cmdtag = stanza.tags[1] local sessionid = cmdtag.attr.sessionid or uuid.generate(); local dataIn = { + origin = origin; + stanza = stanza; to = stanza.attr.to; from = stanza.attr.from; action = cmdtag.attr.action or "execute"; diff --git a/plugins/adhoc/mod_adhoc.lua b/plugins/adhoc/mod_adhoc.lua index 09a72075..9d6ff77a 100644 --- a/plugins/adhoc/mod_adhoc.lua +++ b/plugins/adhoc/mod_adhoc.lua @@ -79,12 +79,12 @@ module:hook("iq-set/host/"..xmlns_cmd..":command", function (event) or (command.permission == "global_admin" and not global_admin) or (command.permission == "local_user" and hostname ~= module.host) then origin.send(st.error_reply(stanza, "auth", "forbidden", "You don't have permission to execute this command"):up() - :add_child(commands[node]:cmdtag("canceled") + :add_child(command:cmdtag("canceled") :tag("note", {type="error"}):text("You don't have permission to execute this command"))); return true end -- User has permission now execute the command - adhoc_handle_cmd(commands[node], origin, stanza); + adhoc_handle_cmd(command, origin, stanza); return true; end end, 500); diff --git a/plugins/mod_admin_shell.lua b/plugins/mod_admin_shell.lua index 94530cf1..363ad5c6 100644 --- a/plugins/mod_admin_shell.lua +++ b/plugins/mod_admin_shell.lua @@ -36,6 +36,7 @@ local serialization = require "util.serialization"; local serialize_config = serialization.new ({ fatal = false, unquoted = true}); local time = require "util.time"; local promise = require "util.promise"; +local logger = require "util.logger"; local t_insert = table.insert; local t_concat = table.concat; @@ -83,8 +84,8 @@ function runner_callbacks:error(err) self.data.print("Error: "..tostring(err)); end -local function send_repl_output(session, line) - return session.send(st.stanza("repl-output"):text(tostring(line))); +local function send_repl_output(session, line, attr) + return session.send(st.stanza("repl-output", attr):text(tostring(line))); end function console:new_session(admin_session) @@ -99,8 +100,14 @@ function console:new_session(admin_session) end return send_repl_output(admin_session, table.concat(t, "\t")); end; + write = function (t) + return send_repl_output(admin_session, t, { eol = "0" }); + end; serialize = tostring; disconnect = function () admin_session:close(); end; + is_connected = function () + return not not admin_session.conn; + end }; session.env = setmetatable({}, default_env_mt); @@ -800,9 +807,7 @@ available_columns = { mapper = function(conn, session) if not session.secure then return "insecure"; end if not conn or not conn:ssl() then return "secure" end - local sock = conn and conn:socket(); - if not sock then return "secure"; end - local tls_info = sock.info and sock:info(); + local tls_info = conn.ssl_info and conn:ssl_info(); return tls_info and tls_info.protocol or "secure"; end; }; @@ -812,8 +817,7 @@ available_columns = { width = 30; key = "conn"; mapper = function(conn) - local sock = conn:socket(); - local info = sock and sock.info and sock:info(); + local info = conn and conn.ssl_info and conn:ssl_info(); if info then return info.cipher end end; }; @@ -1583,6 +1587,60 @@ function def_env.http:list(hosts) return true; end +def_env.watch = {}; + +function def_env.watch:log() + local writing = false; + local sink = logger.add_simple_sink(function (source, level, message) + if writing then return; end + writing = true; + self.session.print(source, level, message); + writing = false; + end); + + while self.session.is_connected() do + async.sleep(3); + end + if not logger.remove_sink(sink) then + module:log("warn", "Unable to remove watch:log() sink"); + end +end + +local stanza_watchers = module:require("mod_debug_stanzas/watcher"); +function def_env.watch:stanzas(target_spec, filter_spec) + local function handler(event_type, stanza, session) + if stanza then + if event_type == "sent" then + self.session.print(("\n<!-- sent to %s -->"):format(session.id)); + elseif event_type == "received" then + self.session.print(("\n<!-- received from %s -->"):format(session.id)); + else + self.session.print(("\n<!-- %s (%s) -->"):format(event_type, session.id)); + end + self.session.print(stanza); + elseif session then + self.session.print("\n<!-- session "..session.id.." "..event_type.." -->"); + elseif event_type then + self.session.print("\n<!-- "..event_type.." -->"); + end + end + + stanza_watchers.add({ + target_spec = { + jid = target_spec; + }; + filter_spec = filter_spec and { + with_jid = filter_spec; + }; + }, handler); + + while self.session.is_connected() do + async.sleep(3); + end + + stanza_watchers.remove(handler); +end + def_env.debug = {}; function def_env.debug:logevents(host) @@ -1926,6 +1984,10 @@ function def_env.stats:show(name_filter) end +function module.unload() + stanza_watchers.cleanup(); +end + ------------- diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua index c8f54fa7..8c0844ae 100644 --- a/plugins/mod_c2s.lua +++ b/plugins/mod_c2s.lua @@ -117,8 +117,7 @@ function stream_callbacks._streamopened(session, attr) session.secure = true; session.encrypted = true; - local sock = session.conn:socket(); - local info = sock.info and sock:info(); + local info = session.conn:ssl_info(); if type(info) == "table" then (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); session.compressed = info.compression; @@ -295,8 +294,7 @@ function listener.onconnect(conn) session.encrypted = true; -- Check if TLS compression is used - local sock = conn:socket(); - local info = sock.info and sock:info(); + local info = conn:ssl_info(); if type(info) == "table" then (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); session.compressed = info.compression; diff --git a/plugins/mod_debug_stanzas/watcher.lib.lua b/plugins/mod_debug_stanzas/watcher.lib.lua new file mode 100644 index 00000000..e21fc946 --- /dev/null +++ b/plugins/mod_debug_stanzas/watcher.lib.lua @@ -0,0 +1,220 @@ +local filters = require "util.filters"; +local jid = require "util.jid"; +local set = require "util.set"; + +local client_watchers = {}; + +-- active_filters[session] = { +-- filter_func = filter_func; +-- downstream = { cb1, cb2, ... }; +-- } +local active_filters = {}; + +local function subscribe_session_stanzas(session, handler, reason) + if active_filters[session] then + table.insert(active_filters[session].downstream, handler); + if reason then + handler(reason, nil, session); + end + return; + end + local downstream = { handler }; + active_filters[session] = { + filter_in = function (stanza) + module:log("debug", "NOTIFY WATCHER %d", #downstream); + for i = 1, #downstream do + downstream[i]("received", stanza, session); + end + return stanza; + end; + filter_out = function (stanza) + module:log("debug", "NOTIFY WATCHER %d", #downstream); + for i = 1, #downstream do + downstream[i]("sent", stanza, session); + end + return stanza; + end; + downstream = downstream; + }; + filters.add_filter(session, "stanzas/in", active_filters[session].filter_in); + filters.add_filter(session, "stanzas/out", active_filters[session].filter_out); + if reason then + handler(reason, nil, session); + end +end + +local function unsubscribe_session_stanzas(session, handler, reason) + local active_filter = active_filters[session]; + if not active_filter then + return; + end + for i = #active_filter.downstream, 1, -1 do + if active_filter.downstream[i] == handler then + table.remove(active_filter.downstream, i); + if reason then + handler(reason, nil, session); + end + end + end + if #active_filter.downstream == 0 then + filters.remove_filter(session, "stanzas/in", active_filter.filter_in); + filters.remove_filter(session, "stanzas/out", active_filter.filter_out); + end + active_filters[session] = nil; +end + +local function unsubscribe_all_from_session(session, reason) + local active_filter = active_filters[session]; + if not active_filter then + return; + end + for i = #active_filter.downstream, 1, -1 do + local handler = table.remove(active_filter.downstream, i); + if reason then + handler(reason, nil, session); + end + end + filters.remove_filter(session, "stanzas/in", active_filter.filter_in); + filters.remove_filter(session, "stanzas/out", active_filter.filter_out); + active_filters[session] = nil; +end + +local function unsubscribe_handler_from_all(handler, reason) + for session in pairs(active_filters) do + unsubscribe_session_stanzas(session, handler, reason); + end +end + +local s2s_watchers = {}; + +module:hook("s2sin-established", function (event) + for _, watcher in ipairs(s2s_watchers) do + if watcher.target_spec == event.session.from_host then + subscribe_session_stanzas(event.session, watcher.handler, "opened"); + end + end +end); + +module:hook("s2sout-established", function (event) + for _, watcher in ipairs(s2s_watchers) do + if watcher.target_spec == event.session.to_host then + subscribe_session_stanzas(event.session, watcher.handler, "opened"); + end + end +end); + +module:hook("s2s-closed", function (event) + unsubscribe_all_from_session(event.session, "closed"); +end); + +local watched_hosts = set.new(); + +local handler_map = setmetatable({}, { __mode = "kv" }); + +local function add_stanza_watcher(spec, orig_handler) + local function filtering_handler(event_type, stanza, session) + if stanza and spec.filter_spec then + if spec.filter_spec.with_jid then + if event_type == "sent" and (not stanza.attr.from or not jid.compare(stanza.attr.from, spec.filter_spec.with_jid)) then + return; + elseif event_type == "received" and (not stanza.attr.to or not jid.compare(stanza.attr.to, spec.filter_spec.with_jid)) then + return; + end + end + end + return orig_handler(event_type, stanza, session); + end + handler_map[orig_handler] = filtering_handler; + if spec.target_spec.jid then + local target_is_remote_host = not jid.node(spec.target_spec.jid) and not prosody.hosts[spec.target_spec.jid]; + + if target_is_remote_host then + -- Watch s2s sessions + table.insert(s2s_watchers, { + target_spec = spec.target_spec.jid; + handler = filtering_handler; + orig_handler = orig_handler; + }); + + -- Scan existing s2sin for matches + for session in pairs(prosody.incoming_s2s) do + if spec.target_spec.jid == session.from_host then + subscribe_session_stanzas(session, filtering_handler, "attached"); + end + end + -- Scan existing s2sout for matches + for local_host, local_session in pairs(prosody.hosts) do --luacheck: ignore 213/local_host + for remote_host, remote_session in pairs(local_session.s2sout) do + if spec.target_spec.jid == remote_host then + subscribe_session_stanzas(remote_session, filtering_handler, "attached"); + end + end + end + else + table.insert(client_watchers, { + target_spec = spec.target_spec.jid; + handler = filtering_handler; + orig_handler = orig_handler; + }); + local host = jid.host(spec.target_spec.jid); + if not watched_hosts:contains(host) and prosody.hosts[host] then + module:context(host):hook("resource-bind", function (event) + for _, watcher in ipairs(client_watchers) do + module:log("debug", "NEW CLIENT: %s vs %s", event.session.full_jid, watcher.target_spec); + if jid.compare(event.session.full_jid, watcher.target_spec) then + module:log("debug", "MATCH"); + subscribe_session_stanzas(event.session, watcher.handler, "opened"); + else + module:log("debug", "NO MATCH"); + end + end + end); + + module:context(host):hook("resource-unbind", function (event) + unsubscribe_all_from_session(event.session, "closed"); + end); + + watched_hosts:add(host); + end + for full_jid, session in pairs(prosody.full_sessions) do + if jid.compare(full_jid, spec.target_spec.jid) then + subscribe_session_stanzas(session, filtering_handler, "attached"); + end + end + end + else + error("No recognized target selector"); + end +end + +local function remove_stanza_watcher(orig_handler) + local handler = handler_map[orig_handler]; + unsubscribe_handler_from_all(handler, "detached"); + handler_map[orig_handler] = nil; + + for i = #client_watchers, 1, -1 do + if client_watchers[i].orig_handler == orig_handler then + table.remove(client_watchers, i); + end + end + + for i = #s2s_watchers, 1, -1 do + if s2s_watchers[i].orig_handler == orig_handler then + table.remove(s2s_watchers, i); + end + end +end + +local function cleanup(reason) + client_watchers = {}; + s2s_watchers = {}; + for session in pairs(active_filters) do + unsubscribe_all_from_session(session, reason or "cancelled"); + end +end + +return { + add = add_stanza_watcher; + remove = remove_stanza_watcher; + cleanup = cleanup; +}; diff --git a/plugins/mod_s2s.lua b/plugins/mod_s2s.lua index e810c6cd..dd585ac7 100644 --- a/plugins/mod_s2s.lua +++ b/plugins/mod_s2s.lua @@ -146,17 +146,17 @@ local function bounce_sendq(session, reason) elseif type(reason) == "string" then reason_text = reason; end - for i, data in ipairs(sendq) do - local reply = data[2]; - if reply and not(reply.attr.xmlns) and bouncy_stanzas[reply.name] then - reply.attr.type = "error"; - reply:tag("error", {type = error_type, by = session.from_host}) - :tag(condition, {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up(); - if reason_text then - reply:tag("text", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}) - :text("Server-to-server connection failed: "..reason_text):up(); - end + for i, stanza in ipairs(sendq) do + if not stanza.attr.xmlns and bouncy_stanzas[stanza.name] and stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then + local reply = st.error_reply( + stanza, + error_type, + condition, + reason_text and ("Server-to-server connection failed: "..reason_text) or nil + ); core_process_stanza(dummy, reply); + else + (session.log or log)("debug", "Not eligible for bouncing, discarding %s", stanza:top_tag()); end sendq[i] = nil; end @@ -182,15 +182,11 @@ function route_to_existing_session(event) (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host); -- Queue stanza until we are able to send it - local queued_item = { - tostring(stanza), - stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza); - }; if host.sendq then - t_insert(host.sendq, queued_item); + t_insert(host.sendq, st.clone(stanza)); else -- luacheck: ignore 122 - host.sendq = { queued_item }; + host.sendq = { st.clone(stanza) }; end host.log("debug", "stanza [%s] queued ", stanza.name); return true; @@ -215,7 +211,7 @@ function route_to_new_session(event) -- Store in buffer host_session.bounce_sendq = bounce_sendq; - host_session.sendq = { {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)} }; + host_session.sendq = { st.clone(stanza) }; log("debug", "stanza [%s] queued until connection complete", stanza.name); -- FIXME Cleaner solution to passing extra data from resolvers to net.server -- This mt-clone allows resolvers to add extra data, currently used for DANE TLSA records @@ -324,8 +320,8 @@ function mark_connected(session) if sendq then session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host); local send = session.sends2s; - for i, data in ipairs(sendq) do - send(data[1]); + for i, stanza in ipairs(sendq) do + send(stanza); sendq[i] = nil; end session.sendq = nil; @@ -389,10 +385,10 @@ end --- Helper to check that a session peer's certificate is valid local function check_cert_status(session) local host = session.direction == "outgoing" and session.to_host or session.from_host - local conn = session.conn:socket() + local conn = session.conn local cert - if conn.getpeercertificate then - cert = conn:getpeercertificate() + if conn.ssl_peercertificate then + cert = conn:ssl_peercertificate() end return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert }); @@ -404,8 +400,7 @@ local function session_secure(session) session.secure = true; session.encrypted = true; - local sock = session.conn:socket(); - local info = sock.info and sock:info(); + local info = session.conn:ssl_info(); if type(info) == "table" then (session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher); session.compressed = info.compression; @@ -935,6 +930,16 @@ local function friendly_cert_error(session) --> string elseif cert_errors:contains("self signed certificate") then return "is self-signed"; end + + local chain_errors = set.new(session.cert_chain_errors[2]); + for i, e in pairs(session.cert_chain_errors) do + if i > 2 then chain_errors:add_list(e); end + end + if chain_errors:contains("certificate has expired") then + return "has an expired certificate chain"; + elseif chain_errors:contains("No matching DANE TLSA records") then + return "does not match any DANE TLSA records"; + end end return "is not trusted"; -- for some other reason elseif session.cert_identity_status == "invalid" then diff --git a/plugins/mod_s2s_auth_certs.lua b/plugins/mod_s2s_auth_certs.lua index 992ee934..bde3cb82 100644 --- a/plugins/mod_s2s_auth_certs.lua +++ b/plugins/mod_s2s_auth_certs.lua @@ -9,7 +9,7 @@ local measure_cert_statuses = module:metric("counter", "checked", "", "Certifica module:hook("s2s-check-certificate", function(event) local session, host, cert = event.session, event.host, event.cert; - local conn = session.conn:socket(); + local conn = session.conn; local log = session.log or log; if not cert then @@ -18,8 +18,8 @@ module:hook("s2s-check-certificate", function(event) end local chain_valid, errors; - if conn.getpeerverification then - chain_valid, errors = conn:getpeerverification(); + if conn.ssl_peerverification then + chain_valid, errors = conn:ssl_peerverification(); else chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } }; end diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index ab863aa3..649f9ba6 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -242,7 +242,7 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-sasl:abort", function(event) end); local function tls_unique(self) - return self.userdata["tls-unique"]:getpeerfinished(); + return self.userdata["tls-unique"]:ssl_peerfinished(); end local mechanisms_attr = { xmlns='urn:ietf:params:xml:ns:xmpp-sasl' }; @@ -262,18 +262,17 @@ module:hook("stream-features", function(event) -- check whether LuaSec has the nifty binding to the function needed for tls-unique -- FIXME: would be nice to have this check only once and not for every socket if sasl_handler.add_cb_handler then - local socket = origin.conn:socket(); - local info = socket.info and socket:info(); - if info.protocol == "TLSv1.3" then + local info = origin.conn:ssl_info(); + if info and info.protocol == "TLSv1.3" then log("debug", "Channel binding 'tls-unique' undefined in context of TLS 1.3"); - elseif socket.getpeerfinished and socket:getpeerfinished() then + elseif origin.conn.ssl_peerfinished and origin.conn:ssl_peerfinished() then log("debug", "Channel binding 'tls-unique' supported"); sasl_handler:add_cb_handler("tls-unique", tls_unique); else log("debug", "Channel binding 'tls-unique' not supported (by LuaSec?)"); end sasl_handler["userdata"] = { - ["tls-unique"] = socket; + ["tls-unique"] = origin.conn; }; else log("debug", "Channel binding not supported by SASL handler"); diff --git a/plugins/mod_smacks.lua b/plugins/mod_smacks.lua index ce59248e..841e1208 100644 --- a/plugins/mod_smacks.lua +++ b/plugins/mod_smacks.lua @@ -2,7 +2,7 @@ -- -- Copyright (C) 2010-2015 Matthew Wild -- Copyright (C) 2010 Waqas Hussain --- Copyright (C) 2012-2021 Kim Alvefur +-- Copyright (C) 2012-2022 Kim Alvefur -- Copyright (C) 2012 Thijs Alkemade -- Copyright (C) 2014 Florian Zeitz -- Copyright (C) 2016-2020 Thilo Molitor @@ -10,6 +10,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +-- TODO unify sendq and smqueue local tonumber = tonumber; local tostring = tostring; @@ -322,26 +323,20 @@ end module:hook_tag(xmlns_sm2, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm2); end, 100); module:hook_tag(xmlns_sm3, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm3); end, 100); -module:hook_tag("http://etherx.jabber.org/streams", "features", - function (session, stanza) - -- Needs to be done after flushing sendq since those aren't stored as - -- stanzas and counting them is weird. - -- TODO unify sendq and smqueue - timer.add_task(1e-6, function () - if can_do_smacks(session) then - if stanza:get_child("sm", xmlns_sm3) then - session.sends2s(st.stanza("enable", sm3_attr)); - session.smacks = xmlns_sm3; - elseif stanza:get_child("sm", xmlns_sm2) then - session.sends2s(st.stanza("enable", sm2_attr)); - session.smacks = xmlns_sm2; - else - return; - end - wrap_session_out(session, false); - end - end); - end); +module:hook_tag("http://etherx.jabber.org/streams", "features", function(session, stanza) + if can_do_smacks(session) then + session.smacks_feature = stanza:get_child("sm", xmlns_sm3) or stanza:get_child("sm", xmlns_sm2); + end +end); + +module:hook("s2sout-established", function (event) + local session = event.session; + if not session.smacks_feature then return end + + session.smacks = session.smacks_feature.attr.xmlns; + wrap_session_out(session, false); + session.sends2s(st.stanza("enable", { xmlns = session.smacks })); +end); function handle_enabled(session, stanza, xmlns_sm) -- luacheck: ignore 212/stanza module:log("debug", "Enabling stream management"); diff --git a/plugins/mod_tls.lua b/plugins/mod_tls.lua index afc1653a..fc35b1d0 100644 --- a/plugins/mod_tls.lua +++ b/plugins/mod_tls.lua @@ -80,6 +80,9 @@ end module:hook_global("config-reloaded", module.load); local function can_do_tls(session) + if session.secure then + return false; + end if session.conn and not session.conn.starttls then if not session.secure then session.log("debug", "Underlying connection does not support STARTTLS"); @@ -126,6 +129,13 @@ end); module:hook("stanza/urn:ietf:params:xml:ns:xmpp-tls:starttls", function(event) local origin = event.origin; if can_do_tls(origin) then + if origin.conn.block_reads then + -- we need to ensure that no data is read anymore, otherwise we could end up in a situation where + -- <proceed/> is sent and the socket receives the TLS handshake (and passes the data to lua) before + -- it is asked to initiate TLS + -- (not with the classical single-threaded server backends) + origin.conn:block_reads() + end (origin.sends2s or origin.send)(starttls_proceed); if origin.destroyed then return end origin:reset_stream(); @@ -183,7 +193,7 @@ module:hook_tag(xmlns_starttls, "proceed", function (session, stanza) -- luachec if session.type == "s2sout_unauthed" and can_do_tls(session) then module:log("debug", "Proceeding with TLS on s2sout..."); session:reset_stream(); - session.conn:starttls(session.ssl_ctx); + session.conn:starttls(session.ssl_ctx, session.to_host); session.secure = false; return true; end diff --git a/spec/net_resolvers_service_spec.lua b/spec/net_resolvers_service_spec.lua new file mode 100644 index 00000000..53ce4754 --- /dev/null +++ b/spec/net_resolvers_service_spec.lua @@ -0,0 +1,241 @@ +local set = require "util.set"; + +insulate("net.resolvers.service", function () + local adns = { + resolver = function () + return { + lookup = function (_, cb, qname, qtype, qclass) + if qname == "_xmpp-server._tcp.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { -- 60+35+60 + srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 }; + }; + { + srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 }; + }; + { + srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 }; + }; + { + srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 }; + }; + + { + srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 }; + }; + { + srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 }; + }; + + { + srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 }; + }; + }); + elseif qname == "_xmpp-server._tcp.single.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { + srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 }; + }; + }); + elseif qname == "_xmpp-server._tcp.half.example.com" + and (qtype or "SRV") == "SRV" + and (qclass or "IN") == "IN" then + cb({ + { + srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 }; + }; + { + srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 }; + }; + }); + elseif qtype == "A" then + local l = qname:match("%-(%a)%.example.com$") or "1"; + local d = ("%d"):format(l:byte()) + cb({ + { + a = "127.0.0."..d; + }; + }); + elseif qtype == "AAAA" then + local l = qname:match("%-(%a)%.example.com$") or "1"; + local d = ("%04d"):format(l:byte()) + cb({ + { + aaaa = "fdeb:9619:649e:c7d9::"..d; + }; + }); + else + cb(nil); + end + end; + }; + end; + }; + package.loaded["net.adns"] = mock(adns); + local resolver = require "net.resolvers.service"; + math.randomseed(os.time()); + it("works for 99% of deployments", function () + -- Most deployments only have a single SRV record, let's make + -- sure that works okay + + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5269"; + "tcp6 fdeb:9619:649e:c7d9::0097 5269"; + }); + local received_targets = set.new({}); + + local r = resolver.new("single.example.com", "xmpp-server"); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + it("supports A/AAAA fallback", function () + -- Many deployments don't have any SRV records, so we should + -- fall back to A/AAAA records instead when that is the case + + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5269"; + "tcp6 fdeb:9619:649e:c7d9::0097 5269"; + }); + local received_targets = set.new({}); + + local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 }); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + + it("works", function () + local expected_targets = set.new({ + -- xmpp0-a + "tcp4 127.0.0.97 5228"; + "tcp6 fdeb:9619:649e:c7d9::0097 5228"; + "tcp4 127.0.0.97 5273"; + "tcp6 fdeb:9619:649e:c7d9::0097 5273"; + + -- xmpp0-b + "tcp4 127.0.0.98 5274"; + "tcp6 fdeb:9619:649e:c7d9::0098 5274"; + "tcp4 127.0.0.98 5216"; + "tcp6 fdeb:9619:649e:c7d9::0098 5216"; + + -- xmpp0-c + "tcp4 127.0.0.99 5200"; + "tcp6 fdeb:9619:649e:c7d9::0099 5200"; + + -- xmpp0-d + "tcp4 127.0.0.100 5256"; + "tcp6 fdeb:9619:649e:c7d9::0100 5256"; + + -- xmpp2 + "tcp4 127.0.0.49 5275"; + "tcp6 fdeb:9619:649e:c7d9::0049 5275"; + + }); + local received_targets = set.new({}); + + local r = resolver.new("example.com", "xmpp-server"); + local done = false; + local function handle_target(...) + if ... == nil then + done = true; + -- No more targets + return; + end + received_targets:add(table.concat({ ... }, " ", 1, 3)); + end + r:next(handle_target); + while not done do + r:next(handle_target); + end + + -- We should have received all expected targets, and no unexpected + -- ones: + assert.truthy(set.xor(received_targets, expected_targets):empty()); + end); + + it("balances across weights correctly #slow", function () + -- This mimics many repeated connections to 'example.com' (mock + -- records defined above), and records the port number of the + -- first target. Therefore it (should) only return priority + -- 0 records, and the input data is constructed such that the + -- last two digits of the port number represent the percentage + -- of times that record should (on average) be picked first. + + -- To prevent random test failures, we test across a handful + -- of fixed (randomly selected) seeds. + for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do + math.randomseed(seed); + + local results = {}; + local function run() + local run_results = {}; + local r = resolver.new("example.com", "xmpp-server"); + local function record_target(...) + if ... == nil then + -- No more targets + return; + end + run_results = { ... }; + end + r:next(record_target); + return run_results[3]; + end + + for _ = 1, 1000 do + local port = run(); + results[port] = (results[port] or 0) + 1; + end + + local ports = {}; + for port in pairs(results) do + table.insert(ports, port); + end + table.sort(ports); + for _, port in ipairs(ports) do + --print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)"); + local hit_pct = (results[port]/1000) * 100; + local expected_pct = port - 5200; + --print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct)); + assert.is_true(math.abs(hit_pct - expected_pct) < 5); + end + --print("---"); + end + end); +end); diff --git a/spec/util_poll_spec.lua b/spec/util_poll_spec.lua index a763be90..05318453 100644 --- a/spec/util_poll_spec.lua +++ b/spec/util_poll_spec.lua @@ -1,6 +1,35 @@ -describe("util.poll", function () - it("loads", function () - require "util.poll" +describe("util.poll", function() + local poll; + setup(function() + poll = require "util.poll"; end); + it("loads", function() + assert.is_table(poll); + assert.is_function(poll.new); + assert.is_string(poll.api); + end); + describe("new", function() + local p; + setup(function() + p = poll.new(); + end) + it("times out", function () + local fd, err = p:wait(0); + assert.falsy(fd); + assert.equal("timeout", err); + end); + it("works", function() + -- stdout should be writable, right? + assert.truthy(p:add(1, false, true)); + local fd, r, w = p:wait(1); + assert.is_number(fd); + assert.is_boolean(r); + assert.is_boolean(w); + assert.equal(1, fd); + assert.falsy(r); + assert.truthy(w); + assert.truthy(p:del(1)); + end); + end) end); diff --git a/spec/util_table_spec.lua b/spec/util_table_spec.lua index 76f54b69..a0535c08 100644 --- a/spec/util_table_spec.lua +++ b/spec/util_table_spec.lua @@ -12,6 +12,17 @@ describe("util.table", function () assert.same({ "lorem", "ipsum", "dolor", "sit", "amet", n = 5 }, u_table.pack("lorem", "ipsum", "dolor", "sit", "amet")); end); end); + + describe("move()", function () + it("works", function () + local t1 = { "apple", "banana", "carrot" }; + local t2 = { "cat", "donkey", "elephant" }; + local t3 = {}; + u_table.move(t1, 1, 3, 1, t3); + u_table.move(t2, 1, 3, 3, t3); + assert.same({ "apple", "banana", "cat", "donkey", "elephant" }, t3); + end); + end); end); diff --git a/teal-src/module.d.tl b/teal-src/module.d.tl index 67b2437c..cb7771e2 100644 --- a/teal-src/module.d.tl +++ b/teal-src/module.d.tl @@ -62,7 +62,12 @@ global record moduleapi send_iq : function (moduleapi, st.stanza_t, util_session, number) broadcast : function (moduleapi, { string }, st.stanza_t, function) type timer_callback = function (number, ... : any) : number - add_timer : function (moduleapi, number, timer_callback, ... : any) + record timer_wrapper + stop : function (timer_wrapper) + disarm : function (timer_wrapper) + reschedule : function (timer_wrapper, number) + end + add_timer : function (moduleapi, number, timer_callback, ... : any) : timer_wrapper get_directory : function (moduleapi) : string enum file_mode "r" "w" "a" "r+" "w+" "a+" diff --git a/teal-src/plugins/mod_cron.tl b/teal-src/plugins/mod_cron.tl index f3b8f62f..7fa2a36b 100644 --- a/teal-src/plugins/mod_cron.tl +++ b/teal-src/plugins/mod_cron.tl @@ -88,8 +88,8 @@ local function run_task(task : task_spec) task:save(started_at); end -local task_runner = async.runner(run_task); -module:add_timer(1, function() : integer +local task_runner : async.runner_t<task_spec> = async.runner(run_task); +scheduled = module:add_timer(1, function() : integer module:log("info", "Running periodic tasks"); local delay = 3600; for host in pairs(active_hosts) do diff --git a/teal-src/util/async.d.tl b/teal-src/util/async.d.tl new file mode 100644 index 00000000..a2e41cd6 --- /dev/null +++ b/teal-src/util/async.d.tl @@ -0,0 +1,42 @@ +local record lib + ready : function () : boolean + waiter : function (num : integer, allow_many : boolean) : function (), function () + guarder : function () : function (id : function ()) : function () | nil + record runner_t<T> + func : function (T) + thread : thread + enum state_e + -- from Lua manual + "running" + "suspended" + "normal" + "dead" + + -- from util.async + "ready" + "error" + end + state : state_e + notified_state : state_e + queue : { T } + type watcher_t = function (runner_t<T>, ... : any) + type watchers_t = { state_e : watcher_t } + data : any + id : string + + run : function (runner_t<T>, T) : boolean, state_e, integer + enqueue : function (runner_t<T>, T) : runner_t<T> + log : function (runner_t<T>, string, string, ... : any) + onready : function (runner_t<T>, function) : runner_t<T> + onready : function (runner_t<T>, function) : runner_t<T> + onwaiting : function (runner_t<T>, function) : runner_t<T> + onerror : function (runner_t<T>, function) : runner_t<T> + end + runner : function <T>(function (T), runner_t.watchers_t, any) : runner_t<T> + wait_for : function (any) : any, any + sleep : function (t:number) + + -- set_nexttick = function(new_next_tick) next_tick = new_next_tick; end; + -- set_schedule_function = function (new_schedule_function) schedule_task = new_schedule_function; end; +end +return lib diff --git a/util-src/crand.c b/util-src/crand.c index 160ac1f6..e4104787 100644 --- a/util-src/crand.c +++ b/util-src/crand.c @@ -45,7 +45,7 @@ #endif /* This wasn't present before glibc 2.25 */ -int getrandom(void *buf, size_t buflen, unsigned int flags) { +static int getrandom(void *buf, size_t buflen, unsigned int flags) { return syscall(SYS_getrandom, buf, buflen, flags); } #else @@ -66,7 +66,7 @@ int getrandom(void *buf, size_t buflen, unsigned int flags) { #define SMALLBUFSIZ 32 #endif -int Lrandom(lua_State *L) { +static int Lrandom(lua_State *L) { char smallbuf[SMALLBUFSIZ]; char *buf = &smallbuf[0]; const lua_Integer l = luaL_checkinteger(L, 1); diff --git a/util-src/strbitop.c b/util-src/strbitop.c index 89fce661..fda8917a 100644 --- a/util-src/strbitop.c +++ b/util-src/strbitop.c @@ -14,7 +14,7 @@ /* TODO Deduplicate code somehow */ -int strop_and(lua_State *L) { +static int strop_and(lua_State *L) { luaL_Buffer buf; size_t a, b, i; const char *str_a = luaL_checklstring(L, 1, &a); @@ -35,7 +35,7 @@ int strop_and(lua_State *L) { return 1; } -int strop_or(lua_State *L) { +static int strop_or(lua_State *L) { luaL_Buffer buf; size_t a, b, i; const char *str_a = luaL_checklstring(L, 1, &a); @@ -56,7 +56,7 @@ int strop_or(lua_State *L) { return 1; } -int strop_xor(lua_State *L) { +static int strop_xor(lua_State *L) { luaL_Buffer buf; size_t a, b, i; const char *str_a = luaL_checklstring(L, 1, &a); diff --git a/util-src/table.c b/util-src/table.c index 9a9553fc..4bbceedb 100644 --- a/util-src/table.c +++ b/util-src/table.c @@ -1,11 +1,21 @@ #include <lua.h> #include <lauxlib.h> +#ifndef LUA_MAXINTEGER +#include <stdint.h> +#define LUA_MAXINTEGER PTRDIFF_MAX +#endif + +#if (LUA_VERSION_NUM > 501) +#define lua_equal(L, A, B) lua_compare(L, A, B, LUA_OPEQ) +#endif + static int Lcreate_table(lua_State *L) { lua_createtable(L, luaL_checkinteger(L, 1), luaL_checkinteger(L, 2)); return 1; } +/* COMPAT: w/ Lua pre-5.4 */ static int Lpack(lua_State *L) { unsigned int n_args = lua_gettop(L); lua_createtable(L, n_args, 1); @@ -20,6 +30,40 @@ static int Lpack(lua_State *L) { return 1; } +/* COMPAT: w/ Lua pre-5.4 */ +static int Lmove (lua_State *L) { + lua_Integer f = luaL_checkinteger(L, 2); + lua_Integer e = luaL_checkinteger(L, 3); + lua_Integer t = luaL_checkinteger(L, 4); + + int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, tt, LUA_TTABLE); + + if (e >= f) { /* otherwise, nothing to move */ + lua_Integer n, i; + luaL_argcheck(L, f > 0 || e < LUA_MAXINTEGER + f, 3, + "too many elements to move"); + n = e - f + 1; /* number of elements to move */ + luaL_argcheck(L, t <= LUA_MAXINTEGER - n + 1, 4, + "destination wrap around"); + if (t > e || t <= f || (tt != 1 && !lua_equal(L, 1, tt))) { + for (i = 0; i < n; i++) { + lua_rawgeti(L, 1, f + i); + lua_rawseti(L, tt, t + i); + } + } else { + for (i = n - 1; i >= 0; i--) { + lua_rawgeti(L, 1, f + i); + lua_rawseti(L, tt, t + i); + } + } + } + + lua_pushvalue(L, tt); /* return destination table */ + return 1; +} + int luaopen_util_table(lua_State *L) { #if (LUA_VERSION_NUM > 501) luaL_checkversion(L); @@ -29,5 +73,7 @@ int luaopen_util_table(lua_State *L) { lua_setfield(L, -2, "create"); lua_pushcfunction(L, Lpack); lua_setfield(L, -2, "pack"); + lua_pushcfunction(L, Lmove); + lua_setfield(L, -2, "move"); return 1; } diff --git a/util/array.lua b/util/array.lua index c33a5ef1..9d438940 100644 --- a/util/array.lua +++ b/util/array.lua @@ -8,6 +8,7 @@ local t_insert, t_sort, t_remove, t_concat = table.insert, table.sort, table.remove, table.concat; +local t_move = require "util.table".move; local setmetatable = setmetatable; local getmetatable = getmetatable; @@ -137,13 +138,11 @@ function array_base.slice(outa, ina, i, j) return outa; end - for idx = 1, 1+j-i do - outa[idx] = ina[i+(idx-1)]; - end + + t_move(ina, i, j, 1, outa); if ina == outa then - for idx = 2+j-i, #outa do - outa[idx] = nil; - end + -- Clear (nil) remainder of range + t_move(ina, #outa+1, #outa*2, 2+j-i, ina); end return outa; end @@ -209,10 +208,7 @@ function array_methods:shuffle() end function array_methods:append(ina) - local len, len2 = #self, #ina; - for i = 1, len2 do - self[len+i] = ina[i]; - end + t_move(ina, 1, #ina, #self+1, self); return self; end diff --git a/util/logger.lua b/util/logger.lua index 20a5cef2..148b98dc 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -10,6 +10,7 @@ local pairs = pairs; local ipairs = ipairs; local require = require; +local t_remove = table.remove; local _ENV = nil; -- luacheck: std none @@ -78,6 +79,20 @@ local function add_simple_sink(simple_sink_function, levels) for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do add_level_sink(level, sink_function); end + return sink_function; +end + +local function remove_sink(sink_function) + local removed; + for level, sinks in pairs(level_sinks) do + for i = #sinks, 1, -1 do + if sinks[i] == sink_function then + t_remove(sinks, i); + removed = true; + end + end + end + return removed; end return { @@ -87,4 +102,5 @@ return { add_level_sink = add_level_sink; add_simple_sink = add_simple_sink; new = make_logger; + remove_sink = remove_sink; }; diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua index bce27b94..0b1dd3f9 100644 --- a/util/prosodyctl/shell.lua +++ b/util/prosodyctl/shell.lua @@ -89,11 +89,15 @@ local function start(arg) --luacheck: ignore 212/arg local errors = 0; -- TODO This is weird, but works for now. client.events.add_handler("received", function(stanza) if stanza.name == "repl-output" or stanza.name == "repl-result" then + local dest = io.stdout; if stanza.attr.type == "error" then errors = errors + 1; - io.stderr:write(stanza:get_text(), "\n"); + dest = io.stderr; + end + if stanza.attr.eol == "0" then + dest:write(stanza:get_text()); else - print(stanza:get_text()); + dest:write(stanza:get_text(), "\n"); end end if stanza.name == "repl-result" then diff --git a/util/sslconfig.lua b/util/sslconfig.lua index 6074a1fb..0078365b 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -3,9 +3,12 @@ local type = type; local pairs = pairs; local rawset = rawset; +local rawget = rawget; +local error = error; local t_concat = table.concat; local t_insert = table.insert; local setmetatable = setmetatable; +local resolve_path = require"util.paths".resolve_relative_path; local _ENV = nil; -- luacheck: std none @@ -34,7 +37,7 @@ function handlers.options(config, field, new) options[value] = true; end end - config[field] = options; + rawset(config, field, options) end handlers.verifyext = handlers.options; @@ -70,6 +73,20 @@ finalisers.curveslist = finalisers.ciphers; -- TLS 1.3 ciphers finalisers.ciphersuites = finalisers.ciphers; +-- Path expansion +function finalisers.key(path, config) + if type(path) == "string" then + return resolve_path(config._basedir, path); + else + return nil + end +end +finalisers.certificate = finalisers.key; +finalisers.cafile = finalisers.key; +finalisers.capath = finalisers.key; +-- XXX: copied from core/certmanager.lua, but this seems odd, because it would remove a dhparam function from the config +finalisers.dhparam = finalisers.key; + -- protocol = "x" should enable only that protocol -- protocol = "x+" should enable x and later versions @@ -89,37 +106,81 @@ end -- Merge options from 'new' config into 'config' local function apply(config, new) + rawset(config, "_cache", nil); if type(new) == "table" then for field, value in pairs(new) do - (handlers[field] or rawset)(config, field, value); + -- exclude keys which are internal to the config builder + if field:sub(1, 1) ~= "_" then + (handlers[field] or rawset)(config, field, value); + end end end + return config end -- Finalize the config into the form LuaSec expects local function final(config) local output = { }; for field, value in pairs(config) do - output[field] = (finalisers[field] or id)(value); + -- exclude keys which are internal to the config builder + if field:sub(1, 1) ~= "_" then + output[field] = (finalisers[field] or id)(value, config); + end end -- Need to handle protocols last because it adds to the options list protocol(output); return output; end +local function build(config) + local cached = rawget(config, "_cache"); + if cached then + return cached, nil + end + + local ctx, err = rawget(config, "_context_factory")(config:final(), config); + if ctx then + rawset(config, "_cache", ctx); + end + return ctx, err +end + local sslopts_mt = { __index = { apply = apply; final = final; + build = build; }; + __newindex = function() + error("SSL config objects cannot be modified directly. Use :apply()") + end; }; -local function new() - return setmetatable({options={}}, sslopts_mt); + +-- passing basedir through everything is required to avoid sslconfig depending +-- on prosody.paths.config +local function new(context_factory, basedir) + return setmetatable({ + _context_factory = context_factory, + _basedir = basedir, + options={}, + }, sslopts_mt); end +local function clone(config) + local result = new(); + for k, v in pairs(config) do + -- note that we *do* copy the internal keys on clone -- we have to carry + -- both the factory and the cache with us + rawset(result, k, v); + end + return result +end + +sslopts_mt.__index.clone = clone; + return { apply = apply; final = final; - new = new; + _new = new; }; diff --git a/util/stanza.lua b/util/stanza.lua index a38f80b3..a14be5f0 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -21,6 +21,8 @@ local type = type; local s_gsub = string.gsub; local s_sub = string.sub; local s_find = string.find; +local t_move = table.move or require "util.table".move; +local t_create = require"util.table".create; local valid_utf8 = require "util.encodings".utf8.valid; @@ -275,25 +277,33 @@ function stanza_mt:find(path) end local function _clone(stanza, only_top) - local attr, tags = {}, {}; + local attr = {}; for k,v in pairs(stanza.attr) do attr[k] = v; end local old_namespaces, namespaces = stanza.namespaces; if old_namespaces then namespaces = {}; for k,v in pairs(old_namespaces) do namespaces[k] = v; end end - local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + local tags, new; + if only_top then + tags = {}; + new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; + else + tags = t_create(#stanza.tags, 0); + new = t_create(#stanza, 4); + new.name = stanza.name; + new.attr = attr; + new.namespaces = namespaces; + new.tags = tags; + end + + setmetatable(new, stanza_mt); if not only_top then - for i=1,#stanza do - local child = stanza[i]; - if child.name then - child = _clone(child); - t_insert(tags, child); - end - t_insert(new, child); - end + t_move(stanza, 1, #stanza, 1, new); + t_move(stanza.tags, 1, #stanza.tags, 1, tags); + new:maptags(_clone); end - return setmetatable(new, stanza_mt); + return new; end local function clone(stanza, only_top) |