aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.luacheckrc2
-rw-r--r--CHANGES20
-rw-r--r--GNUmakefile3
-rw-r--r--core/certmanager.lua36
-rw-r--r--core/configmanager.lua15
-rw-r--r--core/portmanager.lua21
-rw-r--r--doc/doap.xml8
-rw-r--r--makefile3
-rw-r--r--net/connect.lua45
-rw-r--r--net/http/codes.lua92
-rw-r--r--net/resolvers/basic.lua160
-rw-r--r--net/resolvers/service.lua81
-rw-r--r--net/server.lua7
-rw-r--r--net/server_epoll.lua66
-rw-r--r--net/server_event.lua42
-rw-r--r--net/server_select.lua29
-rw-r--r--net/tls_luasec.lua89
-rw-r--r--plugins/adhoc/adhoc.lib.lua2
-rw-r--r--plugins/adhoc/mod_adhoc.lua4
-rw-r--r--plugins/mod_admin_shell.lua121
-rw-r--r--plugins/mod_c2s.lua6
-rw-r--r--plugins/mod_csi_simple.lua5
-rw-r--r--plugins/mod_debug_stanzas/watcher.lib.lua220
-rw-r--r--plugins/mod_mam/mod_mam.lua8
-rw-r--r--plugins/mod_s2s.lua53
-rw-r--r--plugins/mod_s2s_auth_certs.lua6
-rw-r--r--plugins/mod_saslauth.lua21
-rw-r--r--plugins/mod_smacks.lua86
-rw-r--r--plugins/mod_tls.lua12
-rw-r--r--spec/net_resolvers_service_spec.lua241
-rw-r--r--spec/util_poll_spec.lua35
-rw-r--r--spec/util_table_spec.lua11
-rw-r--r--teal-src/module.d.tl7
-rw-r--r--teal-src/plugins/mod_cron.tl4
-rw-r--r--teal-src/util/async.d.tl42
-rw-r--r--teal-src/util/stanza.d.tl88
-rw-r--r--util-src/crand.c4
-rw-r--r--util-src/strbitop.c6
-rw-r--r--util-src/table.c46
-rw-r--r--util/array.lua16
-rw-r--r--util/logger.lua16
-rw-r--r--util/openmetrics.lua7
-rw-r--r--util/prosodyctl/shell.lua19
-rw-r--r--util/sslconfig.lua73
-rw-r--r--util/stanza.lua32
-rw-r--r--util/vcard.lua574
-rw-r--r--util/watchdog.lua39
47 files changed, 1561 insertions, 962 deletions
diff --git a/.luacheckrc b/.luacheckrc
index 0c934d90..d08dfa70 100644
--- a/.luacheckrc
+++ b/.luacheckrc
@@ -1,6 +1,6 @@
cache = true
codes = true
-ignore = { "411/err", "421/err", "411/ok", "421/ok", "211/_ENV", "431/log", }
+ignore = { "411/err", "421/err", "411/ok", "421/ok", "211/_ENV", "431/log", "214", "581" }
std = "lua53c"
max_line_length = 150
diff --git a/CHANGES b/CHANGES
index d963f310..213be9da 100644
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,23 @@
+TRUNK
+=====
+
+## New
+
+### Administration
+
+- Add 'watch log' command to follow live debug logs at runtime (even if disabled)
+
+### Networking
+
+- Honour 'weight' parameter during SRV record selection
+- Support for RFC 8305 "Happy Eyeballs" to improve IPv4/IPv6 connectivity
+- Support for TCP Fast Open in server_epoll (pending LuaSocket support)
+- Support for deferred accept in server_epoll (pending LuaSocket support)
+
+### Security and authentication
+
+- Advertise supported SASL Channel-Binding types (XEP-0440)
+
0.12.0
======
diff --git a/GNUmakefile b/GNUmakefile
index e9ec78c4..c8d2d3dd 100644
--- a/GNUmakefile
+++ b/GNUmakefile
@@ -71,12 +71,13 @@ install-util: util/encodings.so util/encodings.so util/pposix.so util/signal.so
install-plugins:
$(MKDIR) $(MODULES)
- $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam
+ $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam $(MODULES)/mod_debug_stanzas
$(INSTALL_DATA) plugins/*.lua $(MODULES)
$(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub
$(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc
$(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc
$(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam
+ $(INSTALL_DATA) plugins/mod_debug_stanzas/*.lua $(MODULES)/mod_debug_stanzas
install-man:
$(MKDIR) $(MAN)/man1
diff --git a/core/certmanager.lua b/core/certmanager.lua
index 7a82c786..0c71e448 100644
--- a/core/certmanager.lua
+++ b/core/certmanager.lua
@@ -9,9 +9,8 @@
local ssl = require "ssl";
local configmanager = require "core.configmanager";
local log = require "util.logger".init("certmanager");
-local ssl_context = ssl.context or require "ssl.context";
local ssl_newcontext = ssl.newcontext;
-local new_config = require"util.sslconfig".new;
+local new_config = require"net.server".tls_builder;
local stat = require "lfs".attributes;
local x509 = require "util.x509";
@@ -313,10 +312,6 @@ else
core_defaults.curveslist = nil;
end
-local path_options = { -- These we pass through resolve_path()
- key = true, certificate = true, cafile = true, capath = true, dhparam = true
-}
-
local function create_context(host, mode, ...)
local cfg = new_config();
cfg:apply(core_defaults);
@@ -352,34 +347,7 @@ local function create_context(host, mode, ...)
if user_ssl_config.certificate and not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end
end
- for option in pairs(path_options) do
- if type(user_ssl_config[option]) == "string" then
- user_ssl_config[option] = resolve_path(config_path, user_ssl_config[option]);
- else
- user_ssl_config[option] = nil;
- end
- end
-
- -- LuaSec expects dhparam to be a callback that takes two arguments.
- -- We ignore those because it is mostly used for having a separate
- -- set of params for EXPORT ciphers, which we don't have by default.
- if type(user_ssl_config.dhparam) == "string" then
- local f, err = io_open(user_ssl_config.dhparam);
- if not f then return nil, "Could not open DH parameters: "..err end
- local dhparam = f:read("*a");
- f:close();
- user_ssl_config.dhparam = function() return dhparam; end
- end
-
- local ctx, err = ssl_newcontext(user_ssl_config);
-
- -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care
- -- of it ourselves (W/A for #x)
- if ctx and user_ssl_config.ciphers then
- local success;
- success, err = ssl_context.setcipher(ctx, user_ssl_config.ciphers);
- if not success then ctx = nil; end
- end
+ local ctx, err = cfg:build();
if not ctx then
err = err or "invalid ssl config"
diff --git a/core/configmanager.lua b/core/configmanager.lua
index 092b3946..4b8df96e 100644
--- a/core/configmanager.lua
+++ b/core/configmanager.lua
@@ -40,16 +40,10 @@ function _M.getconfig()
return config;
end
-function _M.get(host, key, _oldkey)
- if key == "core" then
- key = _oldkey; -- COMPAT with code that still uses "core"
- end
+function _M.get(host, key)
return config[host][key];
end
-function _M.rawget(host, key, _oldkey)
- if key == "core" then
- key = _oldkey; -- COMPAT with code that still uses "core"
- end
+function _M.rawget(host, key)
local hostconfig = rawget(config, host);
if hostconfig then
return rawget(hostconfig, key);
@@ -68,10 +62,7 @@ local function set(config_table, host, key, value)
return false;
end
-function _M.set(host, key, value, _oldvalue)
- if key == "core" then
- key, value = value, _oldvalue; --COMPAT with code that still uses "core"
- end
+function _M.set(host, key, value)
return set(config, host, key, value);
end
diff --git a/core/portmanager.lua b/core/portmanager.lua
index 38c74b66..8c7dfddb 100644
--- a/core/portmanager.lua
+++ b/core/portmanager.lua
@@ -240,21 +240,22 @@ local function add_sni_host(host, service)
log("debug", "Gathering certificates for SNI for host %s, %s service", host, service or "default");
for name, interface, port, n, active_service --luacheck: ignore 213
in active_services:iter(service, nil, nil, nil) do
- if active_service.server.hosts and active_service.tls_cfg then
- local config_prefix = (active_service.config_prefix or name).."_";
- if config_prefix == "_" then config_prefix = ""; end
- local prefix_ssl_config = config.get(host, config_prefix.."ssl");
+ if active_service.server and active_service.tls_cfg then
local alternate_host = name and config.get(host, name.."_host");
if not alternate_host and name == "https" then
-- TODO should this be some generic thing? e.g. in the service definition
alternate_host = config.get(host, "http_host");
end
local autocert = certmanager.find_host_cert(alternate_host or host);
- -- luacheck: ignore 211/cfg
- local ssl, err, cfg = certmanager.create_context(host, "server", prefix_ssl_config, autocert, active_service.tls_cfg);
- if ssl then
- active_service.server.hosts[alternate_host or host] = ssl;
- else
+ local manualcert = active_service.tls_cfg;
+ local certificate = (autocert and autocert.certificate) or manualcert.certificate;
+ local key = (autocert and autocert.key) or manualcert.key;
+ local ok, err = active_service.server:sslctx():set_sni_host(
+ host,
+ certificate,
+ key
+ );
+ if not ok then
log("error", "Error creating TLS context for SNI host %s: %s", host, err);
end
end
@@ -277,7 +278,7 @@ prosody.events.add_handler("host-deactivated", function (host)
for name, interface, port, n, active_service --luacheck: ignore 213
in active_services:iter(nil, nil, nil, nil) do
if active_service.tls_cfg then
- active_service.server.hosts[host] = nil;
+ active_service.server:sslctx():remove_sni_host(host)
end
end
end);
diff --git a/doc/doap.xml b/doc/doap.xml
index fa3893f8..e767115b 100644
--- a/doc/doap.xml
+++ b/doc/doap.xml
@@ -845,5 +845,13 @@
<xmpp:note>Broken out of XEP-0313</xmpp:note>
</xmpp:SupportedXep>
</implements>
+ <implements>
+ <xmpp:SupportedXep>
+ <xmpp:xep rdf:resource="https://xmpp.org/extensions/xep-0440.html"/>
+ <xmpp:version>0.2.0</xmpp:version>
+ <xmpp:since>trunk</xmpp:since>
+ <xmpp:status>complete</xmpp:status>
+ </xmpp:SupportedXep>
+ </implements>
</Project>
</rdf:RDF>
diff --git a/makefile b/makefile
index cd6340b7..7f17de03 100644
--- a/makefile
+++ b/makefile
@@ -73,12 +73,13 @@ install-util: util/encodings.so util/encodings.so util/pposix.so util/signal.so
install-plugins:
$(MKDIR) $(MODULES)
- $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam
+ $(MKDIR) $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam $(MODULES)/mod_debug_stanzas
$(INSTALL_DATA) plugins/*.lua $(MODULES)
$(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub
$(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc
$(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc
$(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam
+ $(INSTALL_DATA) plugins/mod_debug_stanzas/*.lua $(MODULES)/mod_debug_stanzas
install-man:
$(MKDIR) $(MAN)/man1
diff --git a/net/connect.lua b/net/connect.lua
index 4b602be4..d85afcff 100644
--- a/net/connect.lua
+++ b/net/connect.lua
@@ -1,6 +1,7 @@
local server = require "net.server";
local log = require "util.logger".init("net.connect");
local new_id = require "util.id".short;
+local timer = require "util.timer";
-- TODO #1246 Happy Eyeballs
-- FIXME RFC 6724
@@ -28,16 +29,17 @@ local pending_connection_listeners = {};
local function attempt_connection(p)
p:log("debug", "Checking for targets...");
- if p.conn then
- pending_connections_map[p.conn] = nil;
- p.conn = nil;
- end
- p.target_resolver:next(function (conn_type, ip, port, extra)
+ p.target_resolver:next(function (conn_type, ip, port, extra, more_targets_available)
if not conn_type then
-- No more targets to try
p:log("debug", "No more connection targets to try", p.target_resolver.last_error);
- if p.listeners.onfail then
- p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service");
+ if next(p.conns) == nil then
+ p:log("debug", "No more targets, no pending connections. Connection failed.");
+ if p.listeners.onfail then
+ p.listeners.onfail(p.data, p.last_error or p.target_resolver.last_error or "unable to resolve service");
+ end
+ else
+ p:log("debug", "One or more connection attempts are still pending. Waiting for now.");
end
return;
end
@@ -49,8 +51,16 @@ local function attempt_connection(p)
p.last_error = err or "unknown reason";
return attempt_connection(p);
end
- p.conn = conn;
+ p.conns[conn] = true;
pending_connections_map[conn] = p;
+ if more_targets_available then
+ timer.add_task(0.250, function ()
+ if not p.connected then
+ p:log("debug", "Still not connected, making parallel connection attempt...");
+ attempt_connection(p);
+ end
+ end);
+ end
end);
end
@@ -62,6 +72,13 @@ function pending_connection_listeners.onconnect(conn)
return;
end
pending_connections_map[conn] = nil;
+ if p.connected then
+ -- We already succeeded in connecting
+ p.conns[conn] = nil;
+ conn:close();
+ return;
+ end
+ p.connected = true;
p:log("debug", "Successfully connected");
conn:setlistener(p.listeners, p.data);
return p.listeners.onconnect(conn);
@@ -73,9 +90,18 @@ function pending_connection_listeners.ondisconnect(conn, reason)
log("warn", "Failed connection, but unexpected!");
return;
end
+ p.conns[conn] = nil;
+ pending_connections_map[conn] = nil;
p.last_error = reason or "unknown reason";
p:log("debug", "Connection attempt failed: %s", p.last_error);
- attempt_connection(p);
+ if p.connected then
+ p:log("debug", "Connection already established, ignoring failure");
+ elseif next(p.conns) == nil then
+ p:log("debug", "No pending connection attempts, and not yet connected");
+ attempt_connection(p);
+ else
+ p:log("debug", "Other attempts are still pending, ignoring failure");
+ end
end
local function connect(target_resolver, listeners, options, data)
@@ -85,6 +111,7 @@ local function connect(target_resolver, listeners, options, data)
listeners = assert(listeners);
options = options or {};
data = data;
+ conns = {};
}, pending_connection_mt);
p:log("debug", "Starting connection process");
diff --git a/net/http/codes.lua b/net/http/codes.lua
index 4327f151..b2949286 100644
--- a/net/http/codes.lua
+++ b/net/http/codes.lua
@@ -2,62 +2,62 @@
local response_codes = {
-- Source: http://www.iana.org/assignments/http-status-codes
- [100] = "Continue"; -- RFC7231, Section 6.2.1
- [101] = "Switching Protocols"; -- RFC7231, Section 6.2.2
+ [100] = "Continue"; -- RFC9110, Section 15.2.1
+ [101] = "Switching Protocols"; -- RFC9110, Section 15.2.2
[102] = "Processing";
[103] = "Early Hints";
-- [104-199] = "Unassigned";
- [200] = "OK"; -- RFC7231, Section 6.3.1
- [201] = "Created"; -- RFC7231, Section 6.3.2
- [202] = "Accepted"; -- RFC7231, Section 6.3.3
- [203] = "Non-Authoritative Information"; -- RFC7231, Section 6.3.4
- [204] = "No Content"; -- RFC7231, Section 6.3.5
- [205] = "Reset Content"; -- RFC7231, Section 6.3.6
- [206] = "Partial Content"; -- RFC7233, Section 4.1
+ [200] = "OK"; -- RFC9110, Section 15.3.1
+ [201] = "Created"; -- RFC9110, Section 15.3.2
+ [202] = "Accepted"; -- RFC9110, Section 15.3.3
+ [203] = "Non-Authoritative Information"; -- RFC9110, Section 15.3.4
+ [204] = "No Content"; -- RFC9110, Section 15.3.5
+ [205] = "Reset Content"; -- RFC9110, Section 15.3.6
+ [206] = "Partial Content"; -- RFC9110, Section 15.3.7
[207] = "Multi-Status";
[208] = "Already Reported";
-- [209-225] = "Unassigned";
[226] = "IM Used";
-- [227-299] = "Unassigned";
- [300] = "Multiple Choices"; -- RFC7231, Section 6.4.1
- [301] = "Moved Permanently"; -- RFC7231, Section 6.4.2
- [302] = "Found"; -- RFC7231, Section 6.4.3
- [303] = "See Other"; -- RFC7231, Section 6.4.4
- [304] = "Not Modified"; -- RFC7232, Section 4.1
- [305] = "Use Proxy"; -- RFC7231, Section 6.4.5
- -- [306] = "(Unused)"; -- RFC7231, Section 6.4.6
- [307] = "Temporary Redirect"; -- RFC7231, Section 6.4.7
- [308] = "Permanent Redirect";
+ [300] = "Multiple Choices"; -- RFC9110, Section 15.4.1
+ [301] = "Moved Permanently"; -- RFC9110, Section 15.4.2
+ [302] = "Found"; -- RFC9110, Section 15.4.3
+ [303] = "See Other"; -- RFC9110, Section 15.4.4
+ [304] = "Not Modified"; -- RFC9110, Section 15.4.5
+ [305] = "Use Proxy"; -- RFC9110, Section 15.4.6
+ -- [306] = "(Unused)"; -- RFC9110, Section 15.4.7
+ [307] = "Temporary Redirect"; -- RFC9110, Section 15.4.8
+ [308] = "Permanent Redirect"; -- RFC9110, Section 15.4.9
-- [309-399] = "Unassigned";
- [400] = "Bad Request"; -- RFC7231, Section 6.5.1
- [401] = "Unauthorized"; -- RFC7235, Section 3.1
- [402] = "Payment Required"; -- RFC7231, Section 6.5.2
- [403] = "Forbidden"; -- RFC7231, Section 6.5.3
- [404] = "Not Found"; -- RFC7231, Section 6.5.4
- [405] = "Method Not Allowed"; -- RFC7231, Section 6.5.5
- [406] = "Not Acceptable"; -- RFC7231, Section 6.5.6
- [407] = "Proxy Authentication Required"; -- RFC7235, Section 3.2
- [408] = "Request Timeout"; -- RFC7231, Section 6.5.7
- [409] = "Conflict"; -- RFC7231, Section 6.5.8
- [410] = "Gone"; -- RFC7231, Section 6.5.9
- [411] = "Length Required"; -- RFC7231, Section 6.5.10
- [412] = "Precondition Failed"; -- RFC7232, Section 4.2
- [413] = "Payload Too Large"; -- RFC7231, Section 6.5.11
- [414] = "URI Too Long"; -- RFC7231, Section 6.5.12
- [415] = "Unsupported Media Type"; -- RFC7231, Section 6.5.13
- [416] = "Range Not Satisfiable"; -- RFC7233, Section 4.4
- [417] = "Expectation Failed"; -- RFC7231, Section 6.5.14
+ [400] = "Bad Request"; -- RFC9110, Section 15.5.1
+ [401] = "Unauthorized"; -- RFC9110, Section 15.5.2
+ [402] = "Payment Required"; -- RFC9110, Section 15.5.3
+ [403] = "Forbidden"; -- RFC9110, Section 15.5.4
+ [404] = "Not Found"; -- RFC9110, Section 15.5.5
+ [405] = "Method Not Allowed"; -- RFC9110, Section 15.5.6
+ [406] = "Not Acceptable"; -- RFC9110, Section 15.5.7
+ [407] = "Proxy Authentication Required"; -- RFC9110, Section 15.5.8
+ [408] = "Request Timeout"; -- RFC9110, Section 15.5.9
+ [409] = "Conflict"; -- RFC9110, Section 15.5.10
+ [410] = "Gone"; -- RFC9110, Section 15.5.11
+ [411] = "Length Required"; -- RFC9110, Section 15.5.12
+ [412] = "Precondition Failed"; -- RFC9110, Section 15.5.13
+ [413] = "Content Too Large"; -- RFC9110, Section 15.5.14
+ [414] = "URI Too Long"; -- RFC9110, Section 15.5.15
+ [415] = "Unsupported Media Type"; -- RFC9110, Section 15.5.16
+ [416] = "Range Not Satisfiable"; -- RFC9110, Section 15.5.17
+ [417] = "Expectation Failed"; -- RFC9110, Section 15.5.18
[418] = "I'm a teapot"; -- RFC2324, Section 2.3.2
-- [419-420] = "Unassigned";
- [421] = "Misdirected Request"; -- RFC7540, Section 9.1.2
- [422] = "Unprocessable Entity";
+ [421] = "Misdirected Request"; -- RFC9110, Section 15.5.20
+ [422] = "Unprocessable Content"; -- RFC9110, Section 15.5.21
[423] = "Locked";
[424] = "Failed Dependency";
[425] = "Too Early";
- [426] = "Upgrade Required"; -- RFC7231, Section 6.5.15
+ [426] = "Upgrade Required"; -- RFC9110, Section 15.5.22
-- [427] = "Unassigned";
[428] = "Precondition Required";
[429] = "Too Many Requests";
@@ -67,17 +67,17 @@ local response_codes = {
[451] = "Unavailable For Legal Reasons";
-- [452-499] = "Unassigned";
- [500] = "Internal Server Error"; -- RFC7231, Section 6.6.1
- [501] = "Not Implemented"; -- RFC7231, Section 6.6.2
- [502] = "Bad Gateway"; -- RFC7231, Section 6.6.3
- [503] = "Service Unavailable"; -- RFC7231, Section 6.6.4
- [504] = "Gateway Timeout"; -- RFC7231, Section 6.6.5
- [505] = "HTTP Version Not Supported"; -- RFC7231, Section 6.6.6
+ [500] = "Internal Server Error"; -- RFC9110, Section 15.6.1
+ [501] = "Not Implemented"; -- RFC9110, Section 15.6.2
+ [502] = "Bad Gateway"; -- RFC9110, Section 15.6.3
+ [503] = "Service Unavailable"; -- RFC9110, Section 15.6.4
+ [504] = "Gateway Timeout"; -- RFC9110, Section 15.6.5
+ [505] = "HTTP Version Not Supported"; -- RFC9110, Section 15.6.6
[506] = "Variant Also Negotiates";
[507] = "Insufficient Storage";
[508] = "Loop Detected";
-- [509] = "Unassigned";
- [510] = "Not Extended";
+ [510] = "Not Extended"; -- (OBSOLETED)
[511] = "Network Authentication Required";
-- [512-599] = "Unassigned";
};
diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua
index 305bce76..15338ff4 100644
--- a/net/resolvers/basic.lua
+++ b/net/resolvers/basic.lua
@@ -2,13 +2,59 @@ local adns = require "net.adns";
local inet_pton = require "util.net".pton;
local inet_ntop = require "util.net".ntop;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
+local promise = require "util.promise";
+local t_move = require "util.table".move;
local methods = {};
local resolver_mt = { __index = methods };
-- FIXME RFC 6724
+local function do_dns_lookup(self, dns_resolver, record_type, name)
+ return promise.new(function (resolve, reject)
+ local ipv = (record_type == "A" and "4") or (record_type == "AAAA" and "6") or nil;
+ if ipv and self.extra["use_ipv"..ipv] == false then
+ return reject(("IPv%s disabled - %s lookup skipped"):format(ipv, record_type));
+ elseif record_type == "TLSA" and self.extra.use_dane ~= true then
+ return reject("DANE disabled - TLSA lookup skipped");
+ end
+ dns_resolver:lookup(function (answer, err)
+ if not answer then
+ return reject(err);
+ elseif answer.bogus then
+ return reject(("Validation error in %s lookup"):format(record_type));
+ elseif answer.status and #answer == 0 then
+ return reject(("%s in %s lookup"):format(answer.status, record_type));
+ end
+
+ local targets = { secure = answer.secure };
+ for _, record in ipairs(answer) do
+ if ipv then
+ table.insert(targets, { self.conn_type..ipv, record[record_type:lower()], self.port, self.extra });
+ else
+ table.insert(targets, record[record_type:lower()]);
+ end
+ end
+ return resolve(targets);
+ end, name, record_type, "IN");
+ end);
+end
+
+local function merge_targets(ipv4_targets, ipv6_targets)
+ local result = { secure = ipv4_targets.secure and ipv6_targets.secure };
+ local common_length = math.min(#ipv4_targets, #ipv6_targets);
+ for i = 1, common_length do
+ table.insert(result, ipv6_targets[i]);
+ table.insert(result, ipv4_targets[i]);
+ end
+ if common_length < #ipv4_targets then
+ t_move(ipv4_targets, common_length+1, #ipv4_targets, common_length+1, result);
+ elseif common_length < #ipv6_targets then
+ t_move(ipv6_targets, common_length+1, #ipv6_targets, common_length+1, result);
+ end
+ return result;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
@@ -18,7 +64,7 @@ function methods:next(cb)
return;
end
local next_target = table.remove(self.targets, 1);
- cb(unpack(next_target, 1, 4));
+ cb(next_target[1], next_target[2], next_target[3], next_target[4], not not self.targets[1]);
return;
end
@@ -28,91 +74,45 @@ function methods:next(cb)
return;
end
- local secure = true;
- local tlsa = {};
- local targets = {};
- local n = 3;
- local function ready()
- n = n - 1;
- if n > 0 then return; end
- self.targets = targets;
+ -- Resolve DNS to target list
+ local dns_resolver = adns.resolver();
+
+ local dns_lookups = {
+ ipv4 = do_dns_lookup(self, dns_resolver, "A", self.hostname);
+ ipv6 = do_dns_lookup(self, dns_resolver, "AAAA", self.hostname);
+ tlsa = do_dns_lookup(self, dns_resolver, "TLSA", ("_%d._%s.%s"):format(self.port, self.conn_type, self.hostname));
+ };
+
+ promise.all_settled(dns_lookups):next(function (dns_results)
+ -- Combine targets, assign to self.targets, self:next(cb)
+ local have_ipv4 = dns_results.ipv4.status == "fulfilled";
+ local have_ipv6 = dns_results.ipv6.status == "fulfilled";
+
+ if have_ipv4 and have_ipv6 then
+ self.targets = merge_targets(dns_results.ipv4.value, dns_results.ipv6.value);
+ elseif have_ipv4 then
+ self.targets = dns_results.ipv4.value;
+ elseif have_ipv6 then
+ self.targets = dns_results.ipv6.value;
+ else
+ self.targets = {};
+ end
+
if self.extra and self.extra.use_dane then
- if secure and tlsa[1] then
- self.extra.tlsa = tlsa;
+ if self.targets.secure and dns_results.tlsa.status == "fulfilled" then
+ self.extra.tlsa = dns_results.tlsa.value;
self.extra.dane_hostname = self.hostname;
else
self.extra.tlsa = nil;
self.extra.dane_hostname = nil;
end
end
- self:next(cb);
- end
- -- Resolve DNS to target list
- local dns_resolver = adns.resolver();
-
- if not self.extra or self.extra.use_ipv4 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."4", record.a, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in A lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in A lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "A", "IN");
- else
- ready();
- end
-
- if not self.extra or self.extra.use_ipv6 ~= false then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra });
- end
- if answer.bogus then
- self.last_error = "Validation error in AAAA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in AAAA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, self.hostname, "AAAA", "IN");
- else
- ready();
- end
-
- if self.extra and self.extra.use_dane == true then
- dns_resolver:lookup(function (answer, err)
- if answer then
- secure = secure and answer.secure;
- for _, record in ipairs(answer) do
- table.insert(tlsa, record.tlsa);
- end
- if answer.bogus then
- self.last_error = "Validation error in TLSA lookup";
- elseif answer.status then
- self.last_error = answer.status .. " in TLSA lookup";
- end
- else
- self.last_error = err;
- end
- ready();
- end, ("_%d._tcp.%s"):format(self.port, self.hostname), "TLSA", "IN");
- else
- ready();
- end
+ self:next(cb);
+ end):catch(function (err)
+ self.last_error = err;
+ self.targets = {};
+ end);
end
local function new(hostname, port, conn_type, extra)
@@ -137,7 +137,7 @@ local function new(hostname, port, conn_type, extra)
hostname = ascii_host;
port = port;
conn_type = conn_type;
- extra = extra;
+ extra = extra or {};
targets = targets;
}, resolver_mt);
end
diff --git a/net/resolvers/service.lua b/net/resolvers/service.lua
index 3810cac8..a7ce76a3 100644
--- a/net/resolvers/service.lua
+++ b/net/resolvers/service.lua
@@ -2,23 +2,78 @@ local adns = require "net.adns";
local basic = require "net.resolvers.basic";
local inet_pton = require "util.net".pton;
local idna_to_ascii = require "util.encodings".idna.to_ascii;
-local unpack = table.unpack or unpack; -- luacheck: ignore 113
local methods = {};
local resolver_mt = { __index = methods };
+local function new_target_selector(rrset)
+ local rr_count = rrset and #rrset;
+ if not rr_count or rr_count == 0 then
+ rrset = nil;
+ else
+ table.sort(rrset, function (a, b) return a.srv.priority < b.srv.priority end);
+ end
+ local rrset_pos = 1;
+ local priority_bucket, bucket_total_weight, bucket_len, bucket_used;
+ return function ()
+ if not rrset then return; end
+
+ if not priority_bucket or bucket_used >= bucket_len then
+ if rrset_pos > rr_count then return; end -- Used up all records
+
+ -- Going to start on a new priority now. Gather up all the next
+ -- records with the same priority and add them to priority_bucket
+ priority_bucket, bucket_total_weight, bucket_len, bucket_used = {}, 0, 0, 0;
+ local current_priority;
+ repeat
+ local curr_record = rrset[rrset_pos].srv;
+ if not current_priority then
+ current_priority = curr_record.priority;
+ elseif current_priority ~= curr_record.priority then
+ break;
+ end
+ table.insert(priority_bucket, curr_record);
+ bucket_total_weight = bucket_total_weight + curr_record.weight;
+ bucket_len = bucket_len + 1;
+ rrset_pos = rrset_pos + 1;
+ until rrset_pos > rr_count;
+ end
+
+ bucket_used = bucket_used + 1;
+ local n, running_total = math.random(0, bucket_total_weight), 0;
+ local target_record;
+ for i = 1, bucket_len do
+ local candidate = priority_bucket[i];
+ if candidate then
+ running_total = running_total + candidate.weight;
+ if running_total >= n then
+ target_record = candidate;
+ bucket_total_weight = bucket_total_weight - candidate.weight;
+ priority_bucket[i] = nil;
+ break;
+ end
+ end
+ end
+ return target_record;
+ end;
+end
+
-- Find the next target to connect to, and
-- pass it to cb()
function methods:next(cb)
- if self.targets then
- if not self.resolver then
- if #self.targets == 0 then
+ if self.resolver or self._get_next_target then
+ if not self.resolver then -- Do we have a basic resolver currently?
+ -- We don't, so fetch a new SRV target, create a new basic resolver for it
+ local next_srv_target = self._get_next_target and self._get_next_target();
+ if not next_srv_target then
+ -- No more SRV targets left
cb(nil);
return;
end
- local next_target = table.remove(self.targets, 1);
- self.resolver = basic.new(unpack(next_target, 1, 4));
+ -- Create a new basic resolver for this SRV target
+ self.resolver = basic.new(next_srv_target.target, next_srv_target.port, self.conn_type, self.extra);
end
+ -- Look up the next (basic) target from the current target's resolver
self.resolver:next(function (...)
if self.resolver then
self.last_error = self.resolver.last_error;
@@ -31,6 +86,9 @@ function methods:next(cb)
end
end);
return;
+ elseif self.in_progress then
+ cb(nil);
+ return;
end
if not self.hostname then
@@ -39,9 +97,9 @@ function methods:next(cb)
return;
end
- local targets = {};
+ self.in_progress = true;
+
local function ready()
- self.targets = targets;
self:next(cb);
end
@@ -63,7 +121,7 @@ function methods:next(cb)
if #answer == 0 then
if self.extra and self.extra.default_port then
- table.insert(targets, { self.hostname, self.extra.default_port, self.conn_type, self.extra });
+ self.resolver = basic.new(self.hostname, self.extra.default_port, self.conn_type, self.extra);
else
self.last_error = "zero SRV records found";
end
@@ -77,10 +135,7 @@ function methods:next(cb)
return;
end
- table.sort(answer, function (a, b) return a.srv.priority < b.srv.priority end);
- for _, record in ipairs(answer) do
- table.insert(targets, { record.srv.target, record.srv.port, self.conn_type, self.extra });
- end
+ self._get_next_target = new_target_selector(answer);
else
self.last_error = err;
end
diff --git a/net/server.lua b/net/server.lua
index 0696fd52..72272bef 100644
--- a/net/server.lua
+++ b/net/server.lua
@@ -118,6 +118,13 @@ if prosody and set_config then
prosody.events.add_handler("config-reloaded", load_config);
end
+local tls_builder = server.tls_builder;
+-- resolving the basedir here avoids util.sslconfig depending on
+-- prosody.paths.config
+function server.tls_builder()
+ return tls_builder(prosody.paths.config or "")
+end
+
-- require "net.server" shall now forever return this,
-- ie. server_select or server_event as chosen above.
return server;
diff --git a/net/server_epoll.lua b/net/server_epoll.lua
index fa275d71..a546e1e3 100644
--- a/net/server_epoll.lua
+++ b/net/server_epoll.lua
@@ -18,7 +18,6 @@ local traceback = debug.traceback;
local logger = require "util.logger";
local log = logger.init("server_epoll");
local socket = require "socket";
-local luasec = require "ssl";
local realtime = require "util.time".now;
local monotonic = require "util.time".monotonic;
local indexedbheap = require "util.indexedbheap";
@@ -28,6 +27,8 @@ local inet_pton = inet.pton;
local _SOCKETINVALID = socket._SOCKETINVALID or -1;
local new_id = require "util.id".short;
local xpcall = require "util.xpcall".xpcall;
+local sslconfig = require "util.sslconfig";
+local tls_impl = require "net.tls_luasec";
local poller = require "util.poll"
local EEXIST = poller.EEXIST;
@@ -91,6 +92,12 @@ local default_config = { __index = {
--- How long to wait after getting the shutdown signal before forcefully tearing down every socket
shutdown_deadline = 5;
+
+ -- TCP Fast Open
+ tcp_fastopen = false;
+
+ -- Defer accept until incoming data is available
+ tcp_defer_accept = false;
}};
local cfg = default_config.__index;
@@ -614,6 +621,34 @@ function interface:set_sslctx(sslctx)
self._sslctx = sslctx;
end
+function interface:sslctx()
+ return self.tls_ctx
+end
+
+function interface:ssl_info()
+ local sock = self.conn;
+ if not sock.info then return nil, "not-implemented"; end
+ return sock:info();
+end
+
+function interface:ssl_peercertificate()
+ local sock = self.conn;
+ if not sock.getpeercertificate then return nil, "not-implemented"; end
+ return sock:getpeercertificate();
+end
+
+function interface:ssl_peerverification()
+ local sock = self.conn;
+ if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end
+ return sock:getpeerverification();
+end
+
+function interface:ssl_peerfinished()
+ local sock = self.conn;
+ if not sock.getpeerfinished then return nil, "not-implemented"; end
+ return sock:getpeerfinished();
+end
+
function interface:starttls(tls_ctx)
if tls_ctx then self.tls_ctx = tls_ctx; end
self.starttls = false;
@@ -641,11 +676,7 @@ function interface:inittls(tls_ctx, now)
self.starttls = false;
self:debug("Starting TLS now");
self:updatenames(); -- Can't getpeer/sockname after wrap()
- local ok, conn, err = pcall(luasec.wrap, self.conn, self.tls_ctx);
- if not ok then
- conn, err = ok, conn;
- self:debug("Failed to initialize TLS: %s", err);
- end
+ local conn, err = self.tls_ctx:wrap(self.conn);
if not conn then
self:on("disconnect", err);
self:destroy();
@@ -656,8 +687,8 @@ function interface:inittls(tls_ctx, now)
if conn.sni then
if self.servername then
conn:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- conn:sni(self._server.hosts, true);
+ elseif next(self.tls_ctx._sni_contexts) ~= nil then
+ conn:sni(self.tls_ctx._sni_contexts, true);
end
end
if self.extra and self.extra.tlsa and conn.settlsa then
@@ -741,7 +772,6 @@ local function wrapsocket(client, server, read_size, listeners, tls_ctx, extra)
end
end
- conn:updatenames();
return conn;
end
@@ -767,6 +797,7 @@ function interface:onacceptable()
return;
end
local client = wrapsocket(conn, self, nil, self.listeners);
+ client:updatenames();
client:debug("New connection %s on server %s", client, self);
client:defaultoptions();
client._writable = cfg.opportunistic_writes;
@@ -885,6 +916,12 @@ local function wrapserver(conn, addr, port, listeners, config)
log = logger.init(("serv%s"):format(new_id()));
}, interface_mt);
server:debug("Server %s created", server);
+ if cfg.tcp_fastopen then
+ server:setoption("tcp-fastopen", cfg.tcp_fastopen);
+ end
+ if type(cfg.tcp_defer_accept) == "number" then
+ server:setoption("tcp-defer-accept", cfg.tcp_defer_accept);
+ end
server:add(true, false);
return server;
end
@@ -908,6 +945,7 @@ end
-- COMPAT
local function wrapclient(conn, addr, port, listeners, read_size, tls_ctx, extra)
local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra);
+ client:updatenames();
if not client.peername then
client.peername, client.peerport = addr, port;
end
@@ -941,9 +979,13 @@ local function addclient(addr, port, listeners, read_size, tls_ctx, typ, extra)
if not conn then return conn, err; end
local ok, err = conn:settimeout(0);
if not ok then return ok, err; end
+ local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
+ if cfg.tcp_fastopen then
+ client:setoption("tcp-fastopen-connect", 1);
+ end
local ok, err = conn:setpeername(addr, port);
if not ok and err ~= "timeout" then return ok, err; end
- local client = wrapsocket(conn, nil, read_size, listeners, tls_ctx, extra)
+ client:updatenames();
local ok, err = client:init();
if not client.peername then
-- otherwise not set until connected
@@ -1085,6 +1127,10 @@ return {
cfg = setmetatable(newconfig, default_config);
end;
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
+
-- libevent emulation
event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 };
addevent = function (fd, mode, callback)
diff --git a/net/server_event.lua b/net/server_event.lua
index c30181b8..d8f08c8d 100644
--- a/net/server_event.lua
+++ b/net/server_event.lua
@@ -47,11 +47,13 @@ local s_sub = string.sub
local coroutine_wrap = coroutine.wrap
local coroutine_yield = coroutine.yield
-local has_luasec, ssl = pcall ( require , "ssl" )
+local has_luasec = pcall ( require , "ssl" )
local socket = require "socket"
local levent = require "luaevent.core"
local inet = require "util.net";
local inet_pton = inet.pton;
+local sslconfig = require "util.sslconfig";
+local tls_impl = require "net.tls_luasec";
local socket_gettime = socket.gettime
@@ -153,7 +155,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
_ = self.eventwrite and self.eventwrite:close( )
self.eventread, self.eventwrite = nil, nil
local err
- self.conn, err = ssl.wrap( self.conn, self._sslctx )
+ self.conn, err = self._sslctx:wrap(self.conn)
if err then
self.fatalerror = err
self.conn = nil -- cannot be used anymore
@@ -168,8 +170,8 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed
if self.conn.sni then
if self.servername then
self.conn:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- self.conn:sni(self._server.hosts, true);
+ elseif next(self._sslctx._sni_contexts) ~= nil then
+ self.conn:sni(self._sslctx._sni_contexts, true);
end
end
@@ -274,6 +276,34 @@ function interface_mt:pause()
return self:_lock(self.nointerface, true, self.nowriting);
end
+function interface_mt:sslctx()
+ return self._sslctx
+end
+
+function interface_mt:ssl_info()
+ local sock = self.conn;
+ if not sock.info then return nil, "not-implemented"; end
+ return sock:info();
+end
+
+function interface_mt:ssl_peercertificate()
+ local sock = self.conn;
+ if not sock.getpeercertificate then return nil, "not-implemented"; end
+ return sock:getpeercertificate();
+end
+
+function interface_mt:ssl_peerverification()
+ local sock = self.conn;
+ if not sock.getpeerverification then return nil, { { "Chain verification not supported" } }; end
+ return sock:getpeerverification();
+end
+
+function interface_mt:ssl_peerfinished()
+ local sock = self.conn;
+ if not sock.getpeerfinished then return nil, "not-implemented"; end
+ return sock:getpeerfinished();
+end
+
function interface_mt:resume()
self:_lock(self.nointerface, false, self.nowriting);
if self.readcallback and not self.eventread then
@@ -924,6 +954,10 @@ return {
add_task = add_task,
watchfd = watchfd,
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
+
__NAME = SCRIPT_NAME,
__DATE = LAST_MODIFIED,
__AUTHOR = SCRIPT_AUTHOR,
diff --git a/net/server_select.lua b/net/server_select.lua
index eea850ce..651bdfde 100644
--- a/net/server_select.lua
+++ b/net/server_select.lua
@@ -47,15 +47,15 @@ local coroutine_yield = coroutine.yield
--// extern libs //--
-local has_luasec, luasec = pcall ( require , "ssl" )
local luasocket = use "socket" or require "socket"
local luasocket_gettime = luasocket.gettime
local inet = require "util.net";
local inet_pton = inet.pton;
+local sslconfig = require "util.sslconfig";
+local has_luasec, tls_impl = pcall(require, "net.tls_luasec");
--// extern lib methods //--
-local ssl_wrap = ( has_luasec and luasec.wrap )
local socket_bind = luasocket.bind
local socket_select = luasocket.select
@@ -359,6 +359,21 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
handler.sslctx = function ( )
return sslctx
end
+ handler.ssl_info = function( )
+ return socket.info and socket:info()
+ end
+ handler.ssl_peercertificate = function( )
+ if not socket.getpeercertificate then return nil, "not-implemented"; end
+ return socket:getpeercertificate()
+ end
+ handler.ssl_peerverification = function( )
+ if not socket.getpeerverification then return nil, { { "Chain verification not supported" } }; end
+ return socket:getpeerverification();
+ end
+ handler.ssl_peerfinished = function( )
+ if not socket.getpeerfinished then return nil, "not-implemented"; end
+ return socket:getpeerfinished();
+ end
handler.send = function( _, data, i, j )
return send( socket, data, i, j )
end
@@ -652,7 +667,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
end
out_put( "server.lua: attempting to start tls on " .. tostring( socket ) )
local oldsocket, err = socket
- socket, err = ssl_wrap( socket, sslctx ) -- wrap socket
+ socket, err = sslctx:wrap(socket) -- wrap socket
if not socket then
out_put( "server.lua: error while starting tls on client: ", tostring(err or "unknown error") )
@@ -662,8 +677,8 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport
if socket.sni then
if self.servername then
socket:sni(self.servername);
- elseif self._server and type(self._server.hosts) == "table" and next(self._server.hosts) ~= nil then
- socket:sni(self.server().hosts, true);
+ elseif next(sslctx._sni_contexts) ~= nil then
+ socket:sni(sslctx._sni_contexts, true);
end
end
@@ -1169,4 +1184,8 @@ return {
removeserver = removeserver,
get_backend = get_backend,
changesettings = changesettings,
+
+ tls_builder = function(basedir)
+ return sslconfig._new(tls_impl.new_context, basedir)
+ end,
}
diff --git a/net/tls_luasec.lua b/net/tls_luasec.lua
new file mode 100644
index 00000000..2bedb5ab
--- /dev/null
+++ b/net/tls_luasec.lua
@@ -0,0 +1,89 @@
+-- Prosody IM
+-- Copyright (C) 2021 Prosody folks
+--
+-- This project is MIT/X11 licensed. Please see the
+-- COPYING file in the source package for more information.
+--
+
+--[[
+This file provides a shim abstraction over LuaSec, consolidating some code
+which was previously spread between net.server backends, portmanager and
+certmanager.
+
+The goal is to provide a more or less well-defined API on top of LuaSec which
+abstracts away some of the things which are not needed and simplifies usage of
+commonly used things (such as SNI contexts). Eventually, network backends
+which do not rely on LuaSocket+LuaSec should be able to provide *this* API
+instead of having to mimic LuaSec.
+]]
+local ssl = require "ssl";
+local ssl_newcontext = ssl.newcontext;
+local ssl_context = ssl.context or require "ssl.context";
+local io_open = io.open;
+
+local context_api = {};
+local context_mt = {__index = context_api};
+
+function context_api:set_sni_host(host, cert, key)
+ local ctx, err = self._builder:clone():apply({
+ certificate = cert,
+ key = key,
+ }):build();
+ if not ctx then
+ return false, err
+ end
+
+ self._sni_contexts[host] = ctx._inner
+
+ return true, nil
+end
+
+function context_api:remove_sni_host(host)
+ self._sni_contexts[host] = nil
+end
+
+function context_api:wrap(sock)
+ local ok, conn, err = pcall(ssl.wrap, sock, self._inner);
+ if not ok then
+ return nil, err
+ end
+ return conn, nil
+end
+
+local function new_context(cfg, builder)
+ -- LuaSec expects dhparam to be a callback that takes two arguments.
+ -- We ignore those because it is mostly used for having a separate
+ -- set of params for EXPORT ciphers, which we don't have by default.
+ if type(cfg.dhparam) == "string" then
+ local f, err = io_open(cfg.dhparam);
+ if not f then return nil, "Could not open DH parameters: "..err end
+ local dhparam = f:read("*a");
+ f:close();
+ cfg.dhparam = function() return dhparam; end
+ end
+
+ local inner, err = ssl_newcontext(cfg);
+ if not inner then
+ return nil, err
+ end
+
+ -- COMPAT Older LuaSec ignores the cipher list from the config, so we have to take care
+ -- of it ourselves (W/A for #x)
+ if inner and cfg.ciphers then
+ local success;
+ success, err = ssl_context.setcipher(inner, cfg.ciphers);
+ if not success then
+ return nil, err
+ end
+ end
+
+ return setmetatable({
+ _inner = inner,
+ _builder = builder,
+ _sni_contexts = {},
+ }, context_mt), nil
+end
+
+return {
+ new_context = new_context,
+};
diff --git a/plugins/adhoc/adhoc.lib.lua b/plugins/adhoc/adhoc.lib.lua
index 4cf6911d..eb91f252 100644
--- a/plugins/adhoc/adhoc.lib.lua
+++ b/plugins/adhoc/adhoc.lib.lua
@@ -34,6 +34,8 @@ function _M.handle_cmd(command, origin, stanza)
local cmdtag = stanza.tags[1]
local sessionid = cmdtag.attr.sessionid or uuid.generate();
local dataIn = {
+ origin = origin;
+ stanza = stanza;
to = stanza.attr.to;
from = stanza.attr.from;
action = cmdtag.attr.action or "execute";
diff --git a/plugins/adhoc/mod_adhoc.lua b/plugins/adhoc/mod_adhoc.lua
index 09a72075..9d6ff77a 100644
--- a/plugins/adhoc/mod_adhoc.lua
+++ b/plugins/adhoc/mod_adhoc.lua
@@ -79,12 +79,12 @@ module:hook("iq-set/host/"..xmlns_cmd..":command", function (event)
or (command.permission == "global_admin" and not global_admin)
or (command.permission == "local_user" and hostname ~= module.host) then
origin.send(st.error_reply(stanza, "auth", "forbidden", "You don't have permission to execute this command"):up()
- :add_child(commands[node]:cmdtag("canceled")
+ :add_child(command:cmdtag("canceled")
:tag("note", {type="error"}):text("You don't have permission to execute this command")));
return true
end
-- User has permission now execute the command
- adhoc_handle_cmd(commands[node], origin, stanza);
+ adhoc_handle_cmd(command, origin, stanza);
return true;
end
end, 500);
diff --git a/plugins/mod_admin_shell.lua b/plugins/mod_admin_shell.lua
index 94530cf1..4033d868 100644
--- a/plugins/mod_admin_shell.lua
+++ b/plugins/mod_admin_shell.lua
@@ -36,6 +36,7 @@ local serialization = require "util.serialization";
local serialize_config = serialization.new ({ fatal = false, unquoted = true});
local time = require "util.time";
local promise = require "util.promise";
+local logger = require "util.logger";
local t_insert = table.insert;
local t_concat = table.concat;
@@ -83,8 +84,8 @@ function runner_callbacks:error(err)
self.data.print("Error: "..tostring(err));
end
-local function send_repl_output(session, line)
- return session.send(st.stanza("repl-output"):text(tostring(line)));
+local function send_repl_output(session, line, attr)
+ return session.send(st.stanza("repl-output", attr):text(tostring(line)));
end
function console:new_session(admin_session)
@@ -99,8 +100,14 @@ function console:new_session(admin_session)
end
return send_repl_output(admin_session, table.concat(t, "\t"));
end;
+ write = function (t)
+ return send_repl_output(admin_session, t, { eol = "0" });
+ end;
serialize = tostring;
disconnect = function () admin_session:close(); end;
+ is_connected = function ()
+ return not not admin_session.conn;
+ end
};
session.env = setmetatable({}, default_env_mt);
@@ -126,6 +133,11 @@ local function handle_line(event)
session = console:new_session(event.origin);
event.origin.shell_session = session;
end
+
+ local default_width = 132; -- The common default of 80 is a bit too narrow for e.g. s2s:show(), 132 was another common width for hardware terminals
+ local margin = 2; -- To account for '| ' when lines are printed
+ session.width = (tonumber(event.stanza.attr.width) or default_width)-margin;
+
local line = event.stanza:get_text();
local useglobalenv;
@@ -212,7 +224,7 @@ function commands.help(session, data)
print [[Commands are divided into multiple sections. For help on a particular section, ]]
print [[type: help SECTION (for example, 'help c2s'). Sections are: ]]
print [[]]
- local row = format_table({ { title = "Section"; width = 7 }; { title = "Description"; width = "100%" } })
+ local row = format_table({ { title = "Section", width = 7 }, { title = "Description", width = "100%" } }, session.width)
print(row())
print(row { "c2s"; "Commands to manage local client-to-server sessions" })
print(row { "s2s"; "Commands to manage sessions between this server and others" })
@@ -228,6 +240,7 @@ function commands.help(session, data)
print(row { "dns"; "Commands to manage and inspect the internal DNS resolver" })
print(row { "xmpp"; "Commands for sending XMPP stanzas" })
print(row { "debug"; "Commands for debugging the server" })
+ print(row { "watch"; "Commands for watching live logs from the server" })
print(row { "config"; "Reloading the configuration, etc." })
print(row { "columns"; "Information about customizing session listings" })
print(row { "console"; "Help regarding the console itself" })
@@ -304,6 +317,9 @@ function commands.help(session, data)
print [[debug:logevents(host) - Enable logging of fired events on host]]
print [[debug:events(host, event) - Show registered event handlers]]
print [[debug:timers() - Show information about scheduled timers]]
+ elseif section == "watch" then
+ print [[watch:log() - Follow debug logs]]
+ print [[watch:stanzas(target, filter) - Watch live stanzas matching the specified target and filter]]
elseif section == "console" then
print [[Hey! Welcome to Prosody's admin console.]]
print [[First thing, if you're ever wondering how to get out, simply type 'quit'.]]
@@ -334,7 +350,7 @@ function commands.help(session, data)
meta_columns[2].width = math.max(meta_columns[2].width or 0, #(spec.title or ""));
meta_columns[3].width = math.max(meta_columns[3].width or 0, #(spec.description or ""));
end
- local row = format_table(meta_columns, 120)
+ local row = format_table(meta_columns, session.width)
print(row());
for column, spec in iterators.sorted_pairs(available_columns) do
print(row({ column, spec.title, spec.description }));
@@ -480,6 +496,16 @@ function def_env.module:info(name, hosts)
local function item_name(item) return item.name; end
+ local function task_timefmt(t)
+ if not t then
+ return "no last run time"
+ elseif os.difftime(os.time(), t) < 86400 then
+ return os.date("last run today at %H:%M", t);
+ else
+ return os.date("last run %A at %H:%M", t);
+ end
+ end
+
local friendly_descriptions = {
["adhoc-provider"] = "Ad-hoc commands",
["auth-provider"] = "Authentication provider",
@@ -497,12 +523,22 @@ function def_env.module:info(name, hosts)
["auth-provider"] = item_name,
["storage-provider"] = item_name,
["http-provider"] = function(item, mod) return mod:http_url(item.name, item.default_path); end,
- ["net-provider"] = item_name,
+ ["net-provider"] = function(item)
+ local service_name = item.name;
+ local ports_list = {};
+ for _, interface, port in portmanager.get_active_services():iter(service_name, nil, nil) do
+ table.insert(ports_list, "["..interface.."]:"..port);
+ end
+ if not ports_list[1] then
+ return service_name..": not listening on any ports";
+ end
+ return service_name..": "..table.concat(ports_list, ", ");
+ end,
["measure"] = function(item) return item.name .. " (" .. suf(item.conf and item.conf.unit, " ") .. item.type .. ")"; end,
["metric"] = function(item)
return ("%s (%s%s)%s"):format(item.name, suf(item.mf.unit, " "), item.mf.type_, pre(": ", item.mf.description));
end,
- ["task"] = function (item) return string.format("%s (%s)", item.name or item.id, item.when); end
+ ["task"] = function (item) return string.format("%s (%s, %s)", item.name or item.id, item.when, task_timefmt(item.last)); end
};
for host in hosts do
@@ -800,9 +836,7 @@ available_columns = {
mapper = function(conn, session)
if not session.secure then return "insecure"; end
if not conn or not conn:ssl() then return "secure" end
- local sock = conn and conn:socket();
- if not sock then return "secure"; end
- local tls_info = sock.info and sock:info();
+ local tls_info = conn.ssl_info and conn:ssl_info();
return tls_info and tls_info.protocol or "secure";
end;
};
@@ -812,8 +846,7 @@ available_columns = {
width = 30;
key = "conn";
mapper = function(conn)
- local sock = conn:socket();
- local info = sock and sock.info and sock:info();
+ local info = conn and conn.ssl_info and conn:ssl_info();
if info then return info.cipher end
end;
};
@@ -931,7 +964,7 @@ end
function def_env.c2s:show(match_jid, colspec)
local print = self.session.print;
local columns = get_colspec(colspec, { "id"; "jid"; "ipv"; "status"; "secure"; "smacks"; "csi" });
- local row = format_table(columns, 120);
+ local row = format_table(columns, self.session.width);
local function match(session)
local jid = get_jid(session)
@@ -1014,7 +1047,7 @@ end
function def_env.s2s:show(match_jid, colspec)
local print = self.session.print;
local columns = get_colspec(colspec, { "id"; "host"; "dir"; "remote"; "ipv"; "secure"; "s2s_sasl"; "dialback" });
- local row = format_table(columns, 132);
+ local row = format_table(columns, self.session.width);
local function match(session)
local host, remote = get_s2s_hosts(session);
@@ -1500,7 +1533,7 @@ function def_env.xmpp:ping(localhost, remotehost, timeout)
module:unhook("s2sin-established", onestablished);
module:unhook("s2s-destroyed", ondestroyed);
end):next(function(pong)
- return ("pong from %s in %gs"):format(pong.stanza.attr.from, time.now() - time_start);
+ return ("pong from %s on %s in %gs"):format(pong.stanza.attr.from, pong.origin.id, time.now() - time_start);
end);
end
@@ -1552,7 +1585,7 @@ function def_env.http:list(hosts)
local output = format_table({
{ title = "Module", width = "20%" },
{ title = "URL", width = "80%" },
- }, 132);
+ }, self.session.width);
for _, host in ipairs(hosts) do
local http_apps = modulemanager.get_items("http-provider", host);
@@ -1583,6 +1616,60 @@ function def_env.http:list(hosts)
return true;
end
+def_env.watch = {};
+
+function def_env.watch:log()
+ local writing = false;
+ local sink = logger.add_simple_sink(function (source, level, message)
+ if writing then return; end
+ writing = true;
+ self.session.print(source, level, message);
+ writing = false;
+ end);
+
+ while self.session.is_connected() do
+ async.sleep(3);
+ end
+ if not logger.remove_sink(sink) then
+ module:log("warn", "Unable to remove watch:log() sink");
+ end
+end
+
+local stanza_watchers = module:require("mod_debug_stanzas/watcher");
+function def_env.watch:stanzas(target_spec, filter_spec)
+ local function handler(event_type, stanza, session)
+ if stanza then
+ if event_type == "sent" then
+ self.session.print(("\n<!-- sent to %s -->"):format(session.id));
+ elseif event_type == "received" then
+ self.session.print(("\n<!-- received from %s -->"):format(session.id));
+ else
+ self.session.print(("\n<!-- %s (%s) -->"):format(event_type, session.id));
+ end
+ self.session.print(stanza);
+ elseif session then
+ self.session.print("\n<!-- session "..session.id.." "..event_type.." -->");
+ elseif event_type then
+ self.session.print("\n<!-- "..event_type.." -->");
+ end
+ end
+
+ stanza_watchers.add({
+ target_spec = {
+ jid = target_spec;
+ };
+ filter_spec = filter_spec and {
+ with_jid = filter_spec;
+ };
+ }, handler);
+
+ while self.session.is_connected() do
+ async.sleep(3);
+ end
+
+ stanza_watchers.remove(handler);
+end
+
def_env.debug = {};
function def_env.debug:logevents(host)
@@ -1926,6 +2013,10 @@ function def_env.stats:show(name_filter)
end
+function module.unload()
+ stanza_watchers.cleanup();
+end
+
-------------
diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua
index c8f54fa7..8c0844ae 100644
--- a/plugins/mod_c2s.lua
+++ b/plugins/mod_c2s.lua
@@ -117,8 +117,7 @@ function stream_callbacks._streamopened(session, attr)
session.secure = true;
session.encrypted = true;
- local sock = session.conn:socket();
- local info = sock.info and sock:info();
+ local info = session.conn:ssl_info();
if type(info) == "table" then
(session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher);
session.compressed = info.compression;
@@ -295,8 +294,7 @@ function listener.onconnect(conn)
session.encrypted = true;
-- Check if TLS compression is used
- local sock = conn:socket();
- local info = sock.info and sock:info();
+ local info = conn:ssl_info();
if type(info) == "table" then
(session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher);
session.compressed = info.compression;
diff --git a/plugins/mod_csi_simple.lua b/plugins/mod_csi_simple.lua
index 569916b0..b9a470f5 100644
--- a/plugins/mod_csi_simple.lua
+++ b/plugins/mod_csi_simple.lua
@@ -116,6 +116,9 @@ local flush_reasons = module:metric(
{ "reason" }
);
+local flush_sizes = module:metric("histogram", "flush_stanza_count", "", "Number of stanzas flushed at once", {},
+ { buckets = { 0, 1, 2, 4, 8, 16, 32, 64, 128, 256 } }):with_labels();
+
local function manage_buffer(stanza, session)
local ctr = session.csi_counter or 0;
if session.state ~= "inactive" then
@@ -129,6 +132,7 @@ local function manage_buffer(stanza, session)
session.csi_measure_buffer_hold = nil;
end
flush_reasons:with_labels(why or "important"):add(1);
+ flush_sizes:sample(ctr);
session.log("debug", "Flushing buffer (%s; queue size is %d)", why or "important", session.csi_counter);
session.state = "flushing";
module:fire_event("csi-flushing", { session = session });
@@ -147,6 +151,7 @@ local function flush_buffer(data, session)
session.log("debug", "Flushing buffer (%s; queue size is %d)", "client activity", session.csi_counter);
session.state = "flushing";
module:fire_event("csi-flushing", { session = session });
+ flush_sizes:sample(ctr);
flush_reasons:with_labels("client activity"):add(1);
if session.csi_measure_buffer_hold then
session.csi_measure_buffer_hold();
diff --git a/plugins/mod_debug_stanzas/watcher.lib.lua b/plugins/mod_debug_stanzas/watcher.lib.lua
new file mode 100644
index 00000000..e21fc946
--- /dev/null
+++ b/plugins/mod_debug_stanzas/watcher.lib.lua
@@ -0,0 +1,220 @@
+local filters = require "util.filters";
+local jid = require "util.jid";
+local set = require "util.set";
+
+local client_watchers = {};
+
+-- active_filters[session] = {
+-- filter_func = filter_func;
+-- downstream = { cb1, cb2, ... };
+-- }
+local active_filters = {};
+
+local function subscribe_session_stanzas(session, handler, reason)
+ if active_filters[session] then
+ table.insert(active_filters[session].downstream, handler);
+ if reason then
+ handler(reason, nil, session);
+ end
+ return;
+ end
+ local downstream = { handler };
+ active_filters[session] = {
+ filter_in = function (stanza)
+ module:log("debug", "NOTIFY WATCHER %d", #downstream);
+ for i = 1, #downstream do
+ downstream[i]("received", stanza, session);
+ end
+ return stanza;
+ end;
+ filter_out = function (stanza)
+ module:log("debug", "NOTIFY WATCHER %d", #downstream);
+ for i = 1, #downstream do
+ downstream[i]("sent", stanza, session);
+ end
+ return stanza;
+ end;
+ downstream = downstream;
+ };
+ filters.add_filter(session, "stanzas/in", active_filters[session].filter_in);
+ filters.add_filter(session, "stanzas/out", active_filters[session].filter_out);
+ if reason then
+ handler(reason, nil, session);
+ end
+end
+
+local function unsubscribe_session_stanzas(session, handler, reason)
+ local active_filter = active_filters[session];
+ if not active_filter then
+ return;
+ end
+ for i = #active_filter.downstream, 1, -1 do
+ if active_filter.downstream[i] == handler then
+ table.remove(active_filter.downstream, i);
+ if reason then
+ handler(reason, nil, session);
+ end
+ end
+ end
+ if #active_filter.downstream == 0 then
+ filters.remove_filter(session, "stanzas/in", active_filter.filter_in);
+ filters.remove_filter(session, "stanzas/out", active_filter.filter_out);
+ end
+ active_filters[session] = nil;
+end
+
+local function unsubscribe_all_from_session(session, reason)
+ local active_filter = active_filters[session];
+ if not active_filter then
+ return;
+ end
+ for i = #active_filter.downstream, 1, -1 do
+ local handler = table.remove(active_filter.downstream, i);
+ if reason then
+ handler(reason, nil, session);
+ end
+ end
+ filters.remove_filter(session, "stanzas/in", active_filter.filter_in);
+ filters.remove_filter(session, "stanzas/out", active_filter.filter_out);
+ active_filters[session] = nil;
+end
+
+local function unsubscribe_handler_from_all(handler, reason)
+ for session in pairs(active_filters) do
+ unsubscribe_session_stanzas(session, handler, reason);
+ end
+end
+
+local s2s_watchers = {};
+
+module:hook("s2sin-established", function (event)
+ for _, watcher in ipairs(s2s_watchers) do
+ if watcher.target_spec == event.session.from_host then
+ subscribe_session_stanzas(event.session, watcher.handler, "opened");
+ end
+ end
+end);
+
+module:hook("s2sout-established", function (event)
+ for _, watcher in ipairs(s2s_watchers) do
+ if watcher.target_spec == event.session.to_host then
+ subscribe_session_stanzas(event.session, watcher.handler, "opened");
+ end
+ end
+end);
+
+module:hook("s2s-closed", function (event)
+ unsubscribe_all_from_session(event.session, "closed");
+end);
+
+local watched_hosts = set.new();
+
+local handler_map = setmetatable({}, { __mode = "kv" });
+
+local function add_stanza_watcher(spec, orig_handler)
+ local function filtering_handler(event_type, stanza, session)
+ if stanza and spec.filter_spec then
+ if spec.filter_spec.with_jid then
+ if event_type == "sent" and (not stanza.attr.from or not jid.compare(stanza.attr.from, spec.filter_spec.with_jid)) then
+ return;
+ elseif event_type == "received" and (not stanza.attr.to or not jid.compare(stanza.attr.to, spec.filter_spec.with_jid)) then
+ return;
+ end
+ end
+ end
+ return orig_handler(event_type, stanza, session);
+ end
+ handler_map[orig_handler] = filtering_handler;
+ if spec.target_spec.jid then
+ local target_is_remote_host = not jid.node(spec.target_spec.jid) and not prosody.hosts[spec.target_spec.jid];
+
+ if target_is_remote_host then
+ -- Watch s2s sessions
+ table.insert(s2s_watchers, {
+ target_spec = spec.target_spec.jid;
+ handler = filtering_handler;
+ orig_handler = orig_handler;
+ });
+
+ -- Scan existing s2sin for matches
+ for session in pairs(prosody.incoming_s2s) do
+ if spec.target_spec.jid == session.from_host then
+ subscribe_session_stanzas(session, filtering_handler, "attached");
+ end
+ end
+ -- Scan existing s2sout for matches
+ for local_host, local_session in pairs(prosody.hosts) do --luacheck: ignore 213/local_host
+ for remote_host, remote_session in pairs(local_session.s2sout) do
+ if spec.target_spec.jid == remote_host then
+ subscribe_session_stanzas(remote_session, filtering_handler, "attached");
+ end
+ end
+ end
+ else
+ table.insert(client_watchers, {
+ target_spec = spec.target_spec.jid;
+ handler = filtering_handler;
+ orig_handler = orig_handler;
+ });
+ local host = jid.host(spec.target_spec.jid);
+ if not watched_hosts:contains(host) and prosody.hosts[host] then
+ module:context(host):hook("resource-bind", function (event)
+ for _, watcher in ipairs(client_watchers) do
+ module:log("debug", "NEW CLIENT: %s vs %s", event.session.full_jid, watcher.target_spec);
+ if jid.compare(event.session.full_jid, watcher.target_spec) then
+ module:log("debug", "MATCH");
+ subscribe_session_stanzas(event.session, watcher.handler, "opened");
+ else
+ module:log("debug", "NO MATCH");
+ end
+ end
+ end);
+
+ module:context(host):hook("resource-unbind", function (event)
+ unsubscribe_all_from_session(event.session, "closed");
+ end);
+
+ watched_hosts:add(host);
+ end
+ for full_jid, session in pairs(prosody.full_sessions) do
+ if jid.compare(full_jid, spec.target_spec.jid) then
+ subscribe_session_stanzas(session, filtering_handler, "attached");
+ end
+ end
+ end
+ else
+ error("No recognized target selector");
+ end
+end
+
+local function remove_stanza_watcher(orig_handler)
+ local handler = handler_map[orig_handler];
+ unsubscribe_handler_from_all(handler, "detached");
+ handler_map[orig_handler] = nil;
+
+ for i = #client_watchers, 1, -1 do
+ if client_watchers[i].orig_handler == orig_handler then
+ table.remove(client_watchers, i);
+ end
+ end
+
+ for i = #s2s_watchers, 1, -1 do
+ if s2s_watchers[i].orig_handler == orig_handler then
+ table.remove(s2s_watchers, i);
+ end
+ end
+end
+
+local function cleanup(reason)
+ client_watchers = {};
+ s2s_watchers = {};
+ for session in pairs(active_filters) do
+ unsubscribe_all_from_session(session, reason or "cancelled");
+ end
+end
+
+return {
+ add = add_stanza_watcher;
+ remove = remove_stanza_watcher;
+ cleanup = cleanup;
+};
diff --git a/plugins/mod_mam/mod_mam.lua b/plugins/mod_mam/mod_mam.lua
index 50095e2f..083ae90d 100644
--- a/plugins/mod_mam/mod_mam.lua
+++ b/plugins/mod_mam/mod_mam.lua
@@ -53,8 +53,12 @@ if not archive.find then
end
local use_total = module:get_option_boolean("mam_include_total", true);
-function schedule_cleanup()
- -- replaced later if cleanup is enabled
+function schedule_cleanup(_username, _date) -- luacheck: ignore 212
+ -- Called to make a note of which users have messages on which days, which in
+ -- turn is used to optimize the message expiry routine.
+ --
+ -- This noop is conditionally replaced later depending on retention settings
+ -- and storage backend capabilities.
end
-- Handle prefs.
diff --git a/plugins/mod_s2s.lua b/plugins/mod_s2s.lua
index e810c6cd..dd585ac7 100644
--- a/plugins/mod_s2s.lua
+++ b/plugins/mod_s2s.lua
@@ -146,17 +146,17 @@ local function bounce_sendq(session, reason)
elseif type(reason) == "string" then
reason_text = reason;
end
- for i, data in ipairs(sendq) do
- local reply = data[2];
- if reply and not(reply.attr.xmlns) and bouncy_stanzas[reply.name] then
- reply.attr.type = "error";
- reply:tag("error", {type = error_type, by = session.from_host})
- :tag(condition, {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"}):up();
- if reason_text then
- reply:tag("text", {xmlns = "urn:ietf:params:xml:ns:xmpp-stanzas"})
- :text("Server-to-server connection failed: "..reason_text):up();
- end
+ for i, stanza in ipairs(sendq) do
+ if not stanza.attr.xmlns and bouncy_stanzas[stanza.name] and stanza.attr.type ~= "error" and stanza.attr.type ~= "result" then
+ local reply = st.error_reply(
+ stanza,
+ error_type,
+ condition,
+ reason_text and ("Server-to-server connection failed: "..reason_text) or nil
+ );
core_process_stanza(dummy, reply);
+ else
+ (session.log or log)("debug", "Not eligible for bouncing, discarding %s", stanza:top_tag());
end
sendq[i] = nil;
end
@@ -182,15 +182,11 @@ function route_to_existing_session(event)
(host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host);
-- Queue stanza until we are able to send it
- local queued_item = {
- tostring(stanza),
- stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza);
- };
if host.sendq then
- t_insert(host.sendq, queued_item);
+ t_insert(host.sendq, st.clone(stanza));
else
-- luacheck: ignore 122
- host.sendq = { queued_item };
+ host.sendq = { st.clone(stanza) };
end
host.log("debug", "stanza [%s] queued ", stanza.name);
return true;
@@ -215,7 +211,7 @@ function route_to_new_session(event)
-- Store in buffer
host_session.bounce_sendq = bounce_sendq;
- host_session.sendq = { {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)} };
+ host_session.sendq = { st.clone(stanza) };
log("debug", "stanza [%s] queued until connection complete", stanza.name);
-- FIXME Cleaner solution to passing extra data from resolvers to net.server
-- This mt-clone allows resolvers to add extra data, currently used for DANE TLSA records
@@ -324,8 +320,8 @@ function mark_connected(session)
if sendq then
session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host);
local send = session.sends2s;
- for i, data in ipairs(sendq) do
- send(data[1]);
+ for i, stanza in ipairs(sendq) do
+ send(stanza);
sendq[i] = nil;
end
session.sendq = nil;
@@ -389,10 +385,10 @@ end
--- Helper to check that a session peer's certificate is valid
local function check_cert_status(session)
local host = session.direction == "outgoing" and session.to_host or session.from_host
- local conn = session.conn:socket()
+ local conn = session.conn
local cert
- if conn.getpeercertificate then
- cert = conn:getpeercertificate()
+ if conn.ssl_peercertificate then
+ cert = conn:ssl_peercertificate()
end
return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert });
@@ -404,8 +400,7 @@ local function session_secure(session)
session.secure = true;
session.encrypted = true;
- local sock = session.conn:socket();
- local info = sock.info and sock:info();
+ local info = session.conn:ssl_info();
if type(info) == "table" then
(session.log or log)("info", "Stream encrypted (%s with %s)", info.protocol, info.cipher);
session.compressed = info.compression;
@@ -935,6 +930,16 @@ local function friendly_cert_error(session) --> string
elseif cert_errors:contains("self signed certificate") then
return "is self-signed";
end
+
+ local chain_errors = set.new(session.cert_chain_errors[2]);
+ for i, e in pairs(session.cert_chain_errors) do
+ if i > 2 then chain_errors:add_list(e); end
+ end
+ if chain_errors:contains("certificate has expired") then
+ return "has an expired certificate chain";
+ elseif chain_errors:contains("No matching DANE TLSA records") then
+ return "does not match any DANE TLSA records";
+ end
end
return "is not trusted"; -- for some other reason
elseif session.cert_identity_status == "invalid" then
diff --git a/plugins/mod_s2s_auth_certs.lua b/plugins/mod_s2s_auth_certs.lua
index 992ee934..bde3cb82 100644
--- a/plugins/mod_s2s_auth_certs.lua
+++ b/plugins/mod_s2s_auth_certs.lua
@@ -9,7 +9,7 @@ local measure_cert_statuses = module:metric("counter", "checked", "", "Certifica
module:hook("s2s-check-certificate", function(event)
local session, host, cert = event.session, event.host, event.cert;
- local conn = session.conn:socket();
+ local conn = session.conn;
local log = session.log or log;
if not cert then
@@ -18,8 +18,8 @@ module:hook("s2s-check-certificate", function(event)
end
local chain_valid, errors;
- if conn.getpeerverification then
- chain_valid, errors = conn:getpeerverification();
+ if conn.ssl_peerverification then
+ chain_valid, errors = conn:ssl_peerverification();
else
chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } };
end
diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua
index ab863aa3..0b350c74 100644
--- a/plugins/mod_saslauth.lua
+++ b/plugins/mod_saslauth.lua
@@ -242,7 +242,7 @@ module:hook("stanza/urn:ietf:params:xml:ns:xmpp-sasl:abort", function(event)
end);
local function tls_unique(self)
- return self.userdata["tls-unique"]:getpeerfinished();
+ return self.userdata["tls-unique"]:ssl_peerfinished();
end
local mechanisms_attr = { xmlns='urn:ietf:params:xml:ns:xmpp-sasl' };
@@ -258,22 +258,23 @@ module:hook("stream-features", function(event)
end
local sasl_handler = usermanager_get_sasl_handler(module.host, origin)
origin.sasl_handler = sasl_handler;
+ local channel_bindings = set.new()
if origin.encrypted then
-- check whether LuaSec has the nifty binding to the function needed for tls-unique
-- FIXME: would be nice to have this check only once and not for every socket
if sasl_handler.add_cb_handler then
- local socket = origin.conn:socket();
- local info = socket.info and socket:info();
- if info.protocol == "TLSv1.3" then
+ local info = origin.conn:ssl_info();
+ if info and info.protocol == "TLSv1.3" then
log("debug", "Channel binding 'tls-unique' undefined in context of TLS 1.3");
- elseif socket.getpeerfinished and socket:getpeerfinished() then
+ elseif origin.conn.ssl_peerfinished and origin.conn:ssl_peerfinished() then
log("debug", "Channel binding 'tls-unique' supported");
sasl_handler:add_cb_handler("tls-unique", tls_unique);
+ channel_bindings:add("tls-unique");
else
log("debug", "Channel binding 'tls-unique' not supported (by LuaSec?)");
end
sasl_handler["userdata"] = {
- ["tls-unique"] = socket;
+ ["tls-unique"] = origin.conn;
};
else
log("debug", "Channel binding not supported by SASL handler");
@@ -305,6 +306,14 @@ module:hook("stream-features", function(event)
for mechanism in usable_mechanisms do
mechanisms:tag("mechanism"):text(mechanism):up();
end
+ if not channel_bindings:empty() then
+ -- XXX XEP-0440 is Experimental
+ mechanisms:tag("sasl-channel-binding", {xmlns='urn:xmpp:sasl-cb:0'})
+ for channel_binding in channel_bindings do
+ mechanisms:tag("channel-binding", {type=channel_binding}):up()
+ end
+ mechanisms:up();
+ end
features:add_child(mechanisms);
return;
end
diff --git a/plugins/mod_smacks.lua b/plugins/mod_smacks.lua
index 3a4c7b84..e2bbff9c 100644
--- a/plugins/mod_smacks.lua
+++ b/plugins/mod_smacks.lua
@@ -2,7 +2,7 @@
--
-- Copyright (C) 2010-2015 Matthew Wild
-- Copyright (C) 2010 Waqas Hussain
--- Copyright (C) 2012-2021 Kim Alvefur
+-- Copyright (C) 2012-2022 Kim Alvefur
-- Copyright (C) 2012 Thijs Alkemade
-- Copyright (C) 2014 Florian Zeitz
-- Copyright (C) 2016-2020 Thilo Molitor
@@ -10,6 +10,7 @@
-- This project is MIT/X11 licensed. Please see the
-- COPYING file in the source package for more information.
--
+-- TODO unify sendq and smqueue
local tonumber = tonumber;
local tostring = tostring;
@@ -83,6 +84,22 @@ local all_old_sessions = module:open_store("smacks_h");
local old_session_registry = module:open_store("smacks_h", "map");
local session_registry = module:shared "/*/smacks/resumption-tokens"; -- > user@host/resumption-token --> resource
+local function track_session(session, id)
+ session_registry[jid.join(session.username, session.host, id or session.resumption_token)] = session;
+ session.resumption_token = id;
+end
+
+local function save_old_session(session)
+ session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil;
+ return old_session_registry:set(session.username, session.resumption_token,
+ { h = session.handled_stanza_count; t = os.time() })
+end
+
+local function clear_old_session(session, id)
+ session_registry[jid.join(session.username, session.host, id or session.resumption_token)] = nil;
+ return old_session_registry:set(session.username, id or session.resumption_token, nil)
+end
+
local ack_errors = require"util.error".init("mod_smacks", xmlns_sm3, {
head = { condition = "undefined-condition"; text = "Client acknowledged more stanzas than sent by server" };
tail = { condition = "undefined-condition"; text = "Client acknowledged less stanzas than already acknowledged" };
@@ -155,13 +172,12 @@ end
local function request_ack(session, reason)
local queue = session.outgoing_stanza_queue;
- session.log("debug", "Sending <r> (inside timer, before send) from %s - #queue=%d", reason, queue:count_unacked());
+ session.log("debug", "Sending <r> from %s - #queue=%d", reason, queue:count_unacked());
session.awaiting_ack = true;
(session.sends2s or session.send)(st.stanza("r", { xmlns = session.smacks }))
if session.destroyed then return end -- sending something can trigger destruction
-- expected_h could be lower than this expression e.g. more stanzas added to the queue meanwhile)
session.last_requested_h = queue:count_acked() + queue:count_unacked();
- session.log("debug", "Sending <r> (inside timer, after send) from %s - #queue=%d", reason, queue:count_unacked());
if not session.delayed_ack_timer then
session.delayed_ack_timer = timer.add_task(delayed_ack_timeout, function()
ack_delayed(session, nil); -- we don't know if this is the only new stanza in the queue
@@ -234,8 +250,7 @@ module:hook("pre-session-close", function(event)
if session.smacks == nil then return end
if session.resumption_token then
session.log("debug", "Revoking resumption token");
- session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil;
- old_session_registry:set(session.username, session.resumption_token, nil);
+ clear_old_session(session);
session.resumption_token = nil;
else
session.log("debug", "Session not resumable");
@@ -284,7 +299,7 @@ function handle_enable(session, stanza, xmlns_sm)
if session.username then
local old_sessions, err = all_old_sessions:get(session.username);
- module:log("debug", "Old sessions: %q", old_sessions)
+ session.log("debug", "Old sessions: %q", old_sessions)
if old_sessions then
local keep, count = {}, 0;
for token, info in it.sorted_pairs(old_sessions, function(a, b)
@@ -296,11 +311,11 @@ function handle_enable(session, stanza, xmlns_sm)
end
all_old_sessions:set(session.username, keep);
elseif err then
- module:log("error", "Unable to retrieve old resumption counters: %s", err);
+ session.log("error", "Unable to retrieve old resumption counters: %s", err);
end
end
- module:log("debug", "Enabling stream management");
+ session.log("debug", "Enabling stream management");
session.smacks = xmlns_sm;
wrap_session(session, false);
@@ -310,8 +325,7 @@ function handle_enable(session, stanza, xmlns_sm)
local resume = stanza.attr.resume;
if resume == "true" or resume == "1" then
resume_token = new_id();
- session_registry[jid.join(session.username, session.host, resume_token)] = session;
- session.resumption_token = resume_token;
+ track_session(session, resume_token);
resume_max = tostring(resume_timeout);
end
(session.sends2s or session.send)(st.stanza("enabled", { xmlns = xmlns_sm, id = resume_token, resume = resume, max = resume_max }));
@@ -320,29 +334,23 @@ end
module:hook_tag(xmlns_sm2, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm2); end, 100);
module:hook_tag(xmlns_sm3, "enable", function (session, stanza) return handle_enable(session, stanza, xmlns_sm3); end, 100);
-module:hook_tag("http://etherx.jabber.org/streams", "features",
- function (session, stanza)
- -- Needs to be done after flushing sendq since those aren't stored as
- -- stanzas and counting them is weird.
- -- TODO unify sendq and smqueue
- timer.add_task(1e-6, function ()
- if can_do_smacks(session) then
- if stanza:get_child("sm", xmlns_sm3) then
- session.sends2s(st.stanza("enable", sm3_attr));
- session.smacks = xmlns_sm3;
- elseif stanza:get_child("sm", xmlns_sm2) then
- session.sends2s(st.stanza("enable", sm2_attr));
- session.smacks = xmlns_sm2;
- else
- return;
- end
- wrap_session_out(session, false);
- end
- end);
- end);
+module:hook_tag("http://etherx.jabber.org/streams", "features", function(session, stanza)
+ if can_do_smacks(session) then
+ session.smacks_feature = stanza:get_child("sm", xmlns_sm3) or stanza:get_child("sm", xmlns_sm2);
+ end
+end);
+
+module:hook("s2sout-established", function (event)
+ local session = event.session;
+ if not session.smacks_feature then return end
+
+ session.smacks = session.smacks_feature.attr.xmlns;
+ wrap_session_out(session, false);
+ session.sends2s(st.stanza("enable", { xmlns = session.smacks }));
+end);
function handle_enabled(session, stanza, xmlns_sm) -- luacheck: ignore 212/stanza
- module:log("debug", "Enabling stream management");
+ session.log("debug", "Enabling stream management");
session.smacks = xmlns_sm;
wrap_session_in(session, false);
@@ -356,10 +364,10 @@ module:hook_tag(xmlns_sm3, "enabled", function (session, stanza) return handle_e
function handle_r(origin, stanza, xmlns_sm) -- luacheck: ignore 212/stanza
if not origin.smacks then
- module:log("debug", "Received ack request from non-smack-enabled session");
+ origin.log("debug", "Received ack request from non-smack-enabled session");
return;
end
- module:log("debug", "Received ack request, acking for %d", origin.handled_stanza_count);
+ origin.log("debug", "Received ack request, acking for %d", origin.handled_stanza_count);
-- Reply with <a>
(origin.sends2s or origin.send)(st.stanza("a", { xmlns = xmlns_sm, h = format_h(origin.handled_stanza_count) }));
-- piggyback our own ack request if needed (see request_ack_if_needed() for explanation of last_requested_h)
@@ -412,13 +420,14 @@ local function handle_unacked_stanzas(session)
local queue = session.outgoing_stanza_queue;
local unacked = queue:count_unacked()
if unacked > 0 then
+ local error_from = jid.join(session.username, session.host or module.host);
tx_dropped_stanzas:sample(unacked);
session.smacks = false; -- Disable queueing
session.outgoing_stanza_queue = nil;
for stanza in queue._queue:consume() do
if not module:fire_event("delivery/failure", { session = session, stanza = stanza }) then
if stanza.attr.type ~= "error" and stanza.attr.from ~= session.full_jid then
- local reply = st.error_reply(stanza, "cancel", "recipient-unavailable");
+ local reply = st.error_reply(stanza, "cancel", "recipient-unavailable", nil, error_from);
module:send(reply);
end
end
@@ -485,9 +494,7 @@ module:hook("pre-resource-unbind", function (event)
end
session.log("debug", "Destroying session for hibernating too long");
- session_registry[jid.join(session.username, session.host, session.resumption_token)] = nil;
- old_session_registry:set(session.username, session.resumption_token,
- { h = session.handled_stanza_count; t = os.time() });
+ save_old_session(session);
session.resumption_token = nil;
session.resending_unacked = true; -- stop outgoing_stanza_filter from re-queueing anything anymore
sessionmanager.destroy_session(session, "Hibernating too long");
@@ -544,7 +551,7 @@ function handle_resume(session, stanza, xmlns_sm)
session.send(st.stanza("failed", { xmlns = xmlns_sm, h = format_h(old_session.h) })
:tag("item-not-found", { xmlns = xmlns_errors })
);
- old_session_registry:set(session.username, id, nil);
+ clear_old_session(session, id);
resumption_expired(1);
else
session.log("debug", "Tried to resume non-existent session with id %s", id);
@@ -701,8 +708,7 @@ module:hook_global("server-stopping", function(event)
for _, user in pairs(local_sessions) do
for _, session in pairs(user.sessions) do
if session.resumption_token then
- if old_session_registry:set(session.username, session.resumption_token,
- { h = session.handled_stanza_count; t = os.time() }) then
+ if save_old_session(session) then
session.resumption_token = nil;
-- Deal with unacked stanzas
diff --git a/plugins/mod_tls.lua b/plugins/mod_tls.lua
index afc1653a..fc35b1d0 100644
--- a/plugins/mod_tls.lua
+++ b/plugins/mod_tls.lua
@@ -80,6 +80,9 @@ end
module:hook_global("config-reloaded", module.load);
local function can_do_tls(session)
+ if session.secure then
+ return false;
+ end
if session.conn and not session.conn.starttls then
if not session.secure then
session.log("debug", "Underlying connection does not support STARTTLS");
@@ -126,6 +129,13 @@ end);
module:hook("stanza/urn:ietf:params:xml:ns:xmpp-tls:starttls", function(event)
local origin = event.origin;
if can_do_tls(origin) then
+ if origin.conn.block_reads then
+ -- we need to ensure that no data is read anymore, otherwise we could end up in a situation where
+ -- <proceed/> is sent and the socket receives the TLS handshake (and passes the data to lua) before
+ -- it is asked to initiate TLS
+ -- (not with the classical single-threaded server backends)
+ origin.conn:block_reads()
+ end
(origin.sends2s or origin.send)(starttls_proceed);
if origin.destroyed then return end
origin:reset_stream();
@@ -183,7 +193,7 @@ module:hook_tag(xmlns_starttls, "proceed", function (session, stanza) -- luachec
if session.type == "s2sout_unauthed" and can_do_tls(session) then
module:log("debug", "Proceeding with TLS on s2sout...");
session:reset_stream();
- session.conn:starttls(session.ssl_ctx);
+ session.conn:starttls(session.ssl_ctx, session.to_host);
session.secure = false;
return true;
end
diff --git a/spec/net_resolvers_service_spec.lua b/spec/net_resolvers_service_spec.lua
new file mode 100644
index 00000000..53ce4754
--- /dev/null
+++ b/spec/net_resolvers_service_spec.lua
@@ -0,0 +1,241 @@
+local set = require "util.set";
+
+insulate("net.resolvers.service", function ()
+ local adns = {
+ resolver = function ()
+ return {
+ lookup = function (_, cb, qname, qtype, qclass)
+ if qname == "_xmpp-server._tcp.example.com"
+ and (qtype or "SRV") == "SRV"
+ and (qclass or "IN") == "IN" then
+ cb({
+ { -- 60+35+60
+ srv = { target = "xmpp0-a.example.com", port = 5228, priority = 0, weight = 60 };
+ };
+ {
+ srv = { target = "xmpp0-b.example.com", port = 5216, priority = 0, weight = 35 };
+ };
+ {
+ srv = { target = "xmpp0-c.example.com", port = 5200, priority = 0, weight = 0 };
+ };
+ {
+ srv = { target = "xmpp0-d.example.com", port = 5256, priority = 0, weight = 120 };
+ };
+
+ {
+ srv = { target = "xmpp1-a.example.com", port = 5273, priority = 1, weight = 30 };
+ };
+ {
+ srv = { target = "xmpp1-b.example.com", port = 5274, priority = 1, weight = 30 };
+ };
+
+ {
+ srv = { target = "xmpp2.example.com", port = 5275, priority = 2, weight = 0 };
+ };
+ });
+ elseif qname == "_xmpp-server._tcp.single.example.com"
+ and (qtype or "SRV") == "SRV"
+ and (qclass or "IN") == "IN" then
+ cb({
+ {
+ srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
+ };
+ });
+ elseif qname == "_xmpp-server._tcp.half.example.com"
+ and (qtype or "SRV") == "SRV"
+ and (qclass or "IN") == "IN" then
+ cb({
+ {
+ srv = { target = "xmpp0-a.example.com", port = 5269, priority = 0, weight = 0 };
+ };
+ {
+ srv = { target = "xmpp0-b.example.com", port = 5270, priority = 0, weight = 1 };
+ };
+ });
+ elseif qtype == "A" then
+ local l = qname:match("%-(%a)%.example.com$") or "1";
+ local d = ("%d"):format(l:byte())
+ cb({
+ {
+ a = "127.0.0."..d;
+ };
+ });
+ elseif qtype == "AAAA" then
+ local l = qname:match("%-(%a)%.example.com$") or "1";
+ local d = ("%04d"):format(l:byte())
+ cb({
+ {
+ aaaa = "fdeb:9619:649e:c7d9::"..d;
+ };
+ });
+ else
+ cb(nil);
+ end
+ end;
+ };
+ end;
+ };
+ package.loaded["net.adns"] = mock(adns);
+ local resolver = require "net.resolvers.service";
+ math.randomseed(os.time());
+ it("works for 99% of deployments", function ()
+ -- Most deployments only have a single SRV record, let's make
+ -- sure that works okay
+
+ local expected_targets = set.new({
+ -- xmpp0-a
+ "tcp4 127.0.0.97 5269";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5269";
+ });
+ local received_targets = set.new({});
+
+ local r = resolver.new("single.example.com", "xmpp-server");
+ local done = false;
+ local function handle_target(...)
+ if ... == nil then
+ done = true;
+ -- No more targets
+ return;
+ end
+ received_targets:add(table.concat({ ... }, " ", 1, 3));
+ end
+ r:next(handle_target);
+ while not done do
+ r:next(handle_target);
+ end
+
+ -- We should have received all expected targets, and no unexpected
+ -- ones:
+ assert.truthy(set.xor(received_targets, expected_targets):empty());
+ end);
+
+ it("supports A/AAAA fallback", function ()
+ -- Many deployments don't have any SRV records, so we should
+ -- fall back to A/AAAA records instead when that is the case
+
+ local expected_targets = set.new({
+ -- xmpp0-a
+ "tcp4 127.0.0.97 5269";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5269";
+ });
+ local received_targets = set.new({});
+
+ local r = resolver.new("xmpp0-a.example.com", "xmpp-server", "tcp", { default_port = 5269 });
+ local done = false;
+ local function handle_target(...)
+ if ... == nil then
+ done = true;
+ -- No more targets
+ return;
+ end
+ received_targets:add(table.concat({ ... }, " ", 1, 3));
+ end
+ r:next(handle_target);
+ while not done do
+ r:next(handle_target);
+ end
+
+ -- We should have received all expected targets, and no unexpected
+ -- ones:
+ assert.truthy(set.xor(received_targets, expected_targets):empty());
+ end);
+
+
+ it("works", function ()
+ local expected_targets = set.new({
+ -- xmpp0-a
+ "tcp4 127.0.0.97 5228";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5228";
+ "tcp4 127.0.0.97 5273";
+ "tcp6 fdeb:9619:649e:c7d9::0097 5273";
+
+ -- xmpp0-b
+ "tcp4 127.0.0.98 5274";
+ "tcp6 fdeb:9619:649e:c7d9::0098 5274";
+ "tcp4 127.0.0.98 5216";
+ "tcp6 fdeb:9619:649e:c7d9::0098 5216";
+
+ -- xmpp0-c
+ "tcp4 127.0.0.99 5200";
+ "tcp6 fdeb:9619:649e:c7d9::0099 5200";
+
+ -- xmpp0-d
+ "tcp4 127.0.0.100 5256";
+ "tcp6 fdeb:9619:649e:c7d9::0100 5256";
+
+ -- xmpp2
+ "tcp4 127.0.0.49 5275";
+ "tcp6 fdeb:9619:649e:c7d9::0049 5275";
+
+ });
+ local received_targets = set.new({});
+
+ local r = resolver.new("example.com", "xmpp-server");
+ local done = false;
+ local function handle_target(...)
+ if ... == nil then
+ done = true;
+ -- No more targets
+ return;
+ end
+ received_targets:add(table.concat({ ... }, " ", 1, 3));
+ end
+ r:next(handle_target);
+ while not done do
+ r:next(handle_target);
+ end
+
+ -- We should have received all expected targets, and no unexpected
+ -- ones:
+ assert.truthy(set.xor(received_targets, expected_targets):empty());
+ end);
+
+ it("balances across weights correctly #slow", function ()
+ -- This mimics many repeated connections to 'example.com' (mock
+ -- records defined above), and records the port number of the
+ -- first target. Therefore it (should) only return priority
+ -- 0 records, and the input data is constructed such that the
+ -- last two digits of the port number represent the percentage
+ -- of times that record should (on average) be picked first.
+
+ -- To prevent random test failures, we test across a handful
+ -- of fixed (randomly selected) seeds.
+ for _, seed in ipairs({ 8401877, 3943829, 7830992 }) do
+ math.randomseed(seed);
+
+ local results = {};
+ local function run()
+ local run_results = {};
+ local r = resolver.new("example.com", "xmpp-server");
+ local function record_target(...)
+ if ... == nil then
+ -- No more targets
+ return;
+ end
+ run_results = { ... };
+ end
+ r:next(record_target);
+ return run_results[3];
+ end
+
+ for _ = 1, 1000 do
+ local port = run();
+ results[port] = (results[port] or 0) + 1;
+ end
+
+ local ports = {};
+ for port in pairs(results) do
+ table.insert(ports, port);
+ end
+ table.sort(ports);
+ for _, port in ipairs(ports) do
+ --print("PORT", port, tostring((results[port]/1000) * 100).."% hits (expected "..tostring(port-5200).."%)");
+ local hit_pct = (results[port]/1000) * 100;
+ local expected_pct = port - 5200;
+ --print(hit_pct, expected_pct, math.abs(hit_pct - expected_pct));
+ assert.is_true(math.abs(hit_pct - expected_pct) < 5);
+ end
+ --print("---");
+ end
+ end);
+end);
diff --git a/spec/util_poll_spec.lua b/spec/util_poll_spec.lua
index a763be90..05318453 100644
--- a/spec/util_poll_spec.lua
+++ b/spec/util_poll_spec.lua
@@ -1,6 +1,35 @@
-describe("util.poll", function ()
- it("loads", function ()
- require "util.poll"
+describe("util.poll", function()
+ local poll;
+ setup(function()
+ poll = require "util.poll";
end);
+ it("loads", function()
+ assert.is_table(poll);
+ assert.is_function(poll.new);
+ assert.is_string(poll.api);
+ end);
+ describe("new", function()
+ local p;
+ setup(function()
+ p = poll.new();
+ end)
+ it("times out", function ()
+ local fd, err = p:wait(0);
+ assert.falsy(fd);
+ assert.equal("timeout", err);
+ end);
+ it("works", function()
+ -- stdout should be writable, right?
+ assert.truthy(p:add(1, false, true));
+ local fd, r, w = p:wait(1);
+ assert.is_number(fd);
+ assert.is_boolean(r);
+ assert.is_boolean(w);
+ assert.equal(1, fd);
+ assert.falsy(r);
+ assert.truthy(w);
+ assert.truthy(p:del(1));
+ end);
+ end)
end);
diff --git a/spec/util_table_spec.lua b/spec/util_table_spec.lua
index 76f54b69..a0535c08 100644
--- a/spec/util_table_spec.lua
+++ b/spec/util_table_spec.lua
@@ -12,6 +12,17 @@ describe("util.table", function ()
assert.same({ "lorem", "ipsum", "dolor", "sit", "amet", n = 5 }, u_table.pack("lorem", "ipsum", "dolor", "sit", "amet"));
end);
end);
+
+ describe("move()", function ()
+ it("works", function ()
+ local t1 = { "apple", "banana", "carrot" };
+ local t2 = { "cat", "donkey", "elephant" };
+ local t3 = {};
+ u_table.move(t1, 1, 3, 1, t3);
+ u_table.move(t2, 1, 3, 3, t3);
+ assert.same({ "apple", "banana", "cat", "donkey", "elephant" }, t3);
+ end);
+ end);
end);
diff --git a/teal-src/module.d.tl b/teal-src/module.d.tl
index 67b2437c..cb7771e2 100644
--- a/teal-src/module.d.tl
+++ b/teal-src/module.d.tl
@@ -62,7 +62,12 @@ global record moduleapi
send_iq : function (moduleapi, st.stanza_t, util_session, number)
broadcast : function (moduleapi, { string }, st.stanza_t, function)
type timer_callback = function (number, ... : any) : number
- add_timer : function (moduleapi, number, timer_callback, ... : any)
+ record timer_wrapper
+ stop : function (timer_wrapper)
+ disarm : function (timer_wrapper)
+ reschedule : function (timer_wrapper, number)
+ end
+ add_timer : function (moduleapi, number, timer_callback, ... : any) : timer_wrapper
get_directory : function (moduleapi) : string
enum file_mode
"r" "w" "a" "r+" "w+" "a+"
diff --git a/teal-src/plugins/mod_cron.tl b/teal-src/plugins/mod_cron.tl
index f3b8f62f..7fa2a36b 100644
--- a/teal-src/plugins/mod_cron.tl
+++ b/teal-src/plugins/mod_cron.tl
@@ -88,8 +88,8 @@ local function run_task(task : task_spec)
task:save(started_at);
end
-local task_runner = async.runner(run_task);
-module:add_timer(1, function() : integer
+local task_runner : async.runner_t<task_spec> = async.runner(run_task);
+scheduled = module:add_timer(1, function() : integer
module:log("info", "Running periodic tasks");
local delay = 3600;
for host in pairs(active_hosts) do
diff --git a/teal-src/util/async.d.tl b/teal-src/util/async.d.tl
new file mode 100644
index 00000000..a2e41cd6
--- /dev/null
+++ b/teal-src/util/async.d.tl
@@ -0,0 +1,42 @@
+local record lib
+ ready : function () : boolean
+ waiter : function (num : integer, allow_many : boolean) : function (), function ()
+ guarder : function () : function (id : function ()) : function () | nil
+ record runner_t<T>
+ func : function (T)
+ thread : thread
+ enum state_e
+ -- from Lua manual
+ "running"
+ "suspended"
+ "normal"
+ "dead"
+
+ -- from util.async
+ "ready"
+ "error"
+ end
+ state : state_e
+ notified_state : state_e
+ queue : { T }
+ type watcher_t = function (runner_t<T>, ... : any)
+ type watchers_t = { state_e : watcher_t }
+ data : any
+ id : string
+
+ run : function (runner_t<T>, T) : boolean, state_e, integer
+ enqueue : function (runner_t<T>, T) : runner_t<T>
+ log : function (runner_t<T>, string, string, ... : any)
+ onready : function (runner_t<T>, function) : runner_t<T>
+ onready : function (runner_t<T>, function) : runner_t<T>
+ onwaiting : function (runner_t<T>, function) : runner_t<T>
+ onerror : function (runner_t<T>, function) : runner_t<T>
+ end
+ runner : function <T>(function (T), runner_t.watchers_t, any) : runner_t<T>
+ wait_for : function (any) : any, any
+ sleep : function (t:number)
+
+ -- set_nexttick = function(new_next_tick) next_tick = new_next_tick; end;
+ -- set_schedule_function = function (new_schedule_function) schedule_task = new_schedule_function; end;
+end
+return lib
diff --git a/teal-src/util/stanza.d.tl b/teal-src/util/stanza.d.tl
index a358248a..1f565b88 100644
--- a/teal-src/util/stanza.d.tl
+++ b/teal-src/util/stanza.d.tl
@@ -4,6 +4,39 @@ local record lib
type childtags_iter = function () : stanza_t
type maptags_cb = function ( stanza_t ) : stanza_t
+
+ enum stanza_error_type
+ "auth"
+ "cancel"
+ "continue"
+ "modify"
+ "wait"
+ end
+ enum stanza_error_condition
+ "bad-request"
+ "conflict"
+ "feature-not-implemented"
+ "forbidden"
+ "gone"
+ "internal-server-error"
+ "item-not-found"
+ "jid-malformed"
+ "not-acceptable"
+ "not-allowed"
+ "not-authorized"
+ "policy-violation"
+ "recipient-unavailable"
+ "redirect"
+ "registration-required"
+ "remote-server-not-found"
+ "remote-server-timeout"
+ "resource-constraint"
+ "service-unavailable"
+ "subscription-required"
+ "undefined-condition"
+ "unexpected-request"
+ end
+
record stanza_t
name : string
attr : { string : string }
@@ -35,7 +68,7 @@ local record lib
pretty_print : function ( stanza_t ) : string
pretty_top_tag : function ( stanza_t ) : string
- get_error : function ( stanza_t ) : string, string, string, stanza_t
+ get_error : function ( stanza_t ) : stanza_error_type, stanza_error_condition, string, stanza_t
indent : function ( stanza_t, integer, string ) : stanza_t
end
@@ -45,16 +78,61 @@ local record lib
{ serialized_stanza_t | string }
end
+ record message_attr
+ ["xml:lang"] : string
+ from : string
+ id : string
+ to : string
+ type : message_type
+ enum message_type
+ "chat"
+ "error"
+ "groupchat"
+ "headline"
+ "normal"
+ end
+ end
+
+ record presence_attr
+ ["xml:lang"] : string
+ from : string
+ id : string
+ to : string
+ type : presence_type
+ enum presence_type
+ "error"
+ "probe"
+ "subscribe"
+ "subscribed"
+ "unsubscribe"
+ "unsubscribed"
+ end
+ end
+
+ record iq_attr
+ ["xml:lang"] : string
+ from : string
+ id : string
+ to : string
+ type : iq_type
+ enum iq_type
+ "error"
+ "get"
+ "result"
+ "set"
+ end
+ end
+
stanza : function ( string, { string : string } ) : stanza_t
is_stanza : function ( any ) : boolean
preserialize : function ( stanza_t ) : serialized_stanza_t
deserialize : function ( serialized_stanza_t ) : stanza_t
clone : function ( stanza_t, boolean ) : stanza_t
- message : function ( { string : string }, string ) : stanza_t
- iq : function ( { string : string } ) : stanza_t
+ message : function ( message_attr, string ) : stanza_t
+ iq : function ( iq_attr ) : stanza_t
reply : function ( stanza_t ) : stanza_t
- error_reply : function ( stanza_t, string, string, string, string )
- presence : function ( { string : string } ) : stanza_t
+ error_reply : function ( stanza_t, stanza_error_type, stanza_error_condition, string, string ) : stanza_t
+ presence : function ( presence_attr ) : stanza_t
xml_escape : function ( string ) : string
pretty_print : function ( string ) : string
end
diff --git a/util-src/crand.c b/util-src/crand.c
index 160ac1f6..e4104787 100644
--- a/util-src/crand.c
+++ b/util-src/crand.c
@@ -45,7 +45,7 @@
#endif
/* This wasn't present before glibc 2.25 */
-int getrandom(void *buf, size_t buflen, unsigned int flags) {
+static int getrandom(void *buf, size_t buflen, unsigned int flags) {
return syscall(SYS_getrandom, buf, buflen, flags);
}
#else
@@ -66,7 +66,7 @@ int getrandom(void *buf, size_t buflen, unsigned int flags) {
#define SMALLBUFSIZ 32
#endif
-int Lrandom(lua_State *L) {
+static int Lrandom(lua_State *L) {
char smallbuf[SMALLBUFSIZ];
char *buf = &smallbuf[0];
const lua_Integer l = luaL_checkinteger(L, 1);
diff --git a/util-src/strbitop.c b/util-src/strbitop.c
index 89fce661..fda8917a 100644
--- a/util-src/strbitop.c
+++ b/util-src/strbitop.c
@@ -14,7 +14,7 @@
/* TODO Deduplicate code somehow */
-int strop_and(lua_State *L) {
+static int strop_and(lua_State *L) {
luaL_Buffer buf;
size_t a, b, i;
const char *str_a = luaL_checklstring(L, 1, &a);
@@ -35,7 +35,7 @@ int strop_and(lua_State *L) {
return 1;
}
-int strop_or(lua_State *L) {
+static int strop_or(lua_State *L) {
luaL_Buffer buf;
size_t a, b, i;
const char *str_a = luaL_checklstring(L, 1, &a);
@@ -56,7 +56,7 @@ int strop_or(lua_State *L) {
return 1;
}
-int strop_xor(lua_State *L) {
+static int strop_xor(lua_State *L) {
luaL_Buffer buf;
size_t a, b, i;
const char *str_a = luaL_checklstring(L, 1, &a);
diff --git a/util-src/table.c b/util-src/table.c
index 9a9553fc..4bbceedb 100644
--- a/util-src/table.c
+++ b/util-src/table.c
@@ -1,11 +1,21 @@
#include <lua.h>
#include <lauxlib.h>
+#ifndef LUA_MAXINTEGER
+#include <stdint.h>
+#define LUA_MAXINTEGER PTRDIFF_MAX
+#endif
+
+#if (LUA_VERSION_NUM > 501)
+#define lua_equal(L, A, B) lua_compare(L, A, B, LUA_OPEQ)
+#endif
+
static int Lcreate_table(lua_State *L) {
lua_createtable(L, luaL_checkinteger(L, 1), luaL_checkinteger(L, 2));
return 1;
}
+/* COMPAT: w/ Lua pre-5.4 */
static int Lpack(lua_State *L) {
unsigned int n_args = lua_gettop(L);
lua_createtable(L, n_args, 1);
@@ -20,6 +30,40 @@ static int Lpack(lua_State *L) {
return 1;
}
+/* COMPAT: w/ Lua pre-5.4 */
+static int Lmove (lua_State *L) {
+ lua_Integer f = luaL_checkinteger(L, 2);
+ lua_Integer e = luaL_checkinteger(L, 3);
+ lua_Integer t = luaL_checkinteger(L, 4);
+
+ int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
+ luaL_checktype(L, 1, LUA_TTABLE);
+ luaL_checktype(L, tt, LUA_TTABLE);
+
+ if (e >= f) { /* otherwise, nothing to move */
+ lua_Integer n, i;
+ luaL_argcheck(L, f > 0 || e < LUA_MAXINTEGER + f, 3,
+ "too many elements to move");
+ n = e - f + 1; /* number of elements to move */
+ luaL_argcheck(L, t <= LUA_MAXINTEGER - n + 1, 4,
+ "destination wrap around");
+ if (t > e || t <= f || (tt != 1 && !lua_equal(L, 1, tt))) {
+ for (i = 0; i < n; i++) {
+ lua_rawgeti(L, 1, f + i);
+ lua_rawseti(L, tt, t + i);
+ }
+ } else {
+ for (i = n - 1; i >= 0; i--) {
+ lua_rawgeti(L, 1, f + i);
+ lua_rawseti(L, tt, t + i);
+ }
+ }
+ }
+
+ lua_pushvalue(L, tt); /* return destination table */
+ return 1;
+}
+
int luaopen_util_table(lua_State *L) {
#if (LUA_VERSION_NUM > 501)
luaL_checkversion(L);
@@ -29,5 +73,7 @@ int luaopen_util_table(lua_State *L) {
lua_setfield(L, -2, "create");
lua_pushcfunction(L, Lpack);
lua_setfield(L, -2, "pack");
+ lua_pushcfunction(L, Lmove);
+ lua_setfield(L, -2, "move");
return 1;
}
diff --git a/util/array.lua b/util/array.lua
index c33a5ef1..9d438940 100644
--- a/util/array.lua
+++ b/util/array.lua
@@ -8,6 +8,7 @@
local t_insert, t_sort, t_remove, t_concat
= table.insert, table.sort, table.remove, table.concat;
+local t_move = require "util.table".move;
local setmetatable = setmetatable;
local getmetatable = getmetatable;
@@ -137,13 +138,11 @@ function array_base.slice(outa, ina, i, j)
return outa;
end
- for idx = 1, 1+j-i do
- outa[idx] = ina[i+(idx-1)];
- end
+
+ t_move(ina, i, j, 1, outa);
if ina == outa then
- for idx = 2+j-i, #outa do
- outa[idx] = nil;
- end
+ -- Clear (nil) remainder of range
+ t_move(ina, #outa+1, #outa*2, 2+j-i, ina);
end
return outa;
end
@@ -209,10 +208,7 @@ function array_methods:shuffle()
end
function array_methods:append(ina)
- local len, len2 = #self, #ina;
- for i = 1, len2 do
- self[len+i] = ina[i];
- end
+ t_move(ina, 1, #ina, #self+1, self);
return self;
end
diff --git a/util/logger.lua b/util/logger.lua
index 20a5cef2..148b98dc 100644
--- a/util/logger.lua
+++ b/util/logger.lua
@@ -10,6 +10,7 @@
local pairs = pairs;
local ipairs = ipairs;
local require = require;
+local t_remove = table.remove;
local _ENV = nil;
-- luacheck: std none
@@ -78,6 +79,20 @@ local function add_simple_sink(simple_sink_function, levels)
for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do
add_level_sink(level, sink_function);
end
+ return sink_function;
+end
+
+local function remove_sink(sink_function)
+ local removed;
+ for level, sinks in pairs(level_sinks) do
+ for i = #sinks, 1, -1 do
+ if sinks[i] == sink_function then
+ t_remove(sinks, i);
+ removed = true;
+ end
+ end
+ end
+ return removed;
end
return {
@@ -87,4 +102,5 @@ return {
add_level_sink = add_level_sink;
add_simple_sink = add_simple_sink;
new = make_logger;
+ remove_sink = remove_sink;
};
diff --git a/util/openmetrics.lua b/util/openmetrics.lua
index cb7791ec..634f9de1 100644
--- a/util/openmetrics.lua
+++ b/util/openmetrics.lua
@@ -35,10 +35,11 @@ local t_pack, t_unpack = require "util.table".pack, table.unpack or unpack; --lu
-- `with_partial_label` by the moduleapi in order to pre-set the `host` label
-- on metrics created in non-global modules.
local metric_proxy_mt = {}
+metric_proxy_mt.__name = "metric_proxy"
metric_proxy_mt.__index = metric_proxy_mt
local function new_metric_proxy(metric_family, with_labels_proxy_fun)
- return {
+ return setmetatable({
_family = metric_family,
with_labels = function(self, ...)
return with_labels_proxy_fun(self._family, ...)
@@ -48,7 +49,7 @@ local function new_metric_proxy(metric_family, with_labels_proxy_fun)
return family:with_labels(label, ...)
end)
end
- }
+ }, metric_proxy_mt);
end
-- END of Utility: "metric proxy"
@@ -128,6 +129,7 @@ end
-- BEGIN of generic MetricFamily implementation
local metric_family_mt = {}
+metric_family_mt.__name = "metric_family"
metric_family_mt.__index = metric_family_mt
local function histogram_metric_ctor(orig_ctor, buckets)
@@ -278,6 +280,7 @@ local function compose_name(name, unit)
end
local metric_registry_mt = {}
+metric_registry_mt.__name = "metric_registry"
metric_registry_mt.__index = metric_registry_mt
local function new_metric_registry(backend)
diff --git a/util/prosodyctl/shell.lua b/util/prosodyctl/shell.lua
index bce27b94..8f910301 100644
--- a/util/prosodyctl/shell.lua
+++ b/util/prosodyctl/shell.lua
@@ -4,6 +4,8 @@ local st = require "util.stanza";
local path = require "util.paths";
local parse_args = require "util.argparse".parse;
local unpack = table.unpack or _G.unpack;
+local tc = require "util.termcolours";
+local isatty = require "util.pposix".isatty;
local have_readline, readline = pcall(require, "readline");
@@ -27,7 +29,7 @@ local function read_line(prompt_string)
end
local function send_line(client, line)
- client.send(st.stanza("repl-input"):text(line));
+ client.send(st.stanza("repl-input", { width = os.getenv "COLUMNS" }):text(line));
end
local function repl(client)
@@ -64,6 +66,7 @@ end
local function start(arg) --luacheck: ignore 212/arg
local client = adminstream.client();
local opts, err, where = parse_args(arg);
+ local ttyout = isatty(io.stdout);
if not opts then
if err == "param-not-found" then
@@ -89,11 +92,15 @@ local function start(arg) --luacheck: ignore 212/arg
local errors = 0; -- TODO This is weird, but works for now.
client.events.add_handler("received", function(stanza)
if stanza.name == "repl-output" or stanza.name == "repl-result" then
+ local dest = io.stdout;
if stanza.attr.type == "error" then
errors = errors + 1;
- io.stderr:write(stanza:get_text(), "\n");
+ dest = io.stderr;
+ end
+ if stanza.attr.eol == "0" then
+ dest:write(stanza:get_text());
else
- print(stanza:get_text());
+ dest:write(stanza:get_text(), "\n");
end
end
if stanza.name == "repl-result" then
@@ -118,7 +125,11 @@ local function start(arg) --luacheck: ignore 212/arg
client.events.add_handler("received", function (stanza)
if stanza.name == "repl-output" or stanza.name == "repl-result" then
local result_prefix = stanza.attr.type == "error" and "!" or "|";
- print(result_prefix.." "..stanza:get_text());
+ local out = result_prefix.." "..stanza:get_text();
+ if ttyout and stanza.attr.type == "error" then
+ out = tc.getstring(tc.getstyle("red"), out);
+ end
+ print(out);
end
if stanza.name == "repl-result" then
repl(client);
diff --git a/util/sslconfig.lua b/util/sslconfig.lua
index 6074a1fb..0078365b 100644
--- a/util/sslconfig.lua
+++ b/util/sslconfig.lua
@@ -3,9 +3,12 @@
local type = type;
local pairs = pairs;
local rawset = rawset;
+local rawget = rawget;
+local error = error;
local t_concat = table.concat;
local t_insert = table.insert;
local setmetatable = setmetatable;
+local resolve_path = require"util.paths".resolve_relative_path;
local _ENV = nil;
-- luacheck: std none
@@ -34,7 +37,7 @@ function handlers.options(config, field, new)
options[value] = true;
end
end
- config[field] = options;
+ rawset(config, field, options)
end
handlers.verifyext = handlers.options;
@@ -70,6 +73,20 @@ finalisers.curveslist = finalisers.ciphers;
-- TLS 1.3 ciphers
finalisers.ciphersuites = finalisers.ciphers;
+-- Path expansion
+function finalisers.key(path, config)
+ if type(path) == "string" then
+ return resolve_path(config._basedir, path);
+ else
+ return nil
+ end
+end
+finalisers.certificate = finalisers.key;
+finalisers.cafile = finalisers.key;
+finalisers.capath = finalisers.key;
+-- XXX: copied from core/certmanager.lua, but this seems odd, because it would remove a dhparam function from the config
+finalisers.dhparam = finalisers.key;
+
-- protocol = "x" should enable only that protocol
-- protocol = "x+" should enable x and later versions
@@ -89,37 +106,81 @@ end
-- Merge options from 'new' config into 'config'
local function apply(config, new)
+ rawset(config, "_cache", nil);
if type(new) == "table" then
for field, value in pairs(new) do
- (handlers[field] or rawset)(config, field, value);
+ -- exclude keys which are internal to the config builder
+ if field:sub(1, 1) ~= "_" then
+ (handlers[field] or rawset)(config, field, value);
+ end
end
end
+ return config
end
-- Finalize the config into the form LuaSec expects
local function final(config)
local output = { };
for field, value in pairs(config) do
- output[field] = (finalisers[field] or id)(value);
+ -- exclude keys which are internal to the config builder
+ if field:sub(1, 1) ~= "_" then
+ output[field] = (finalisers[field] or id)(value, config);
+ end
end
-- Need to handle protocols last because it adds to the options list
protocol(output);
return output;
end
+local function build(config)
+ local cached = rawget(config, "_cache");
+ if cached then
+ return cached, nil
+ end
+
+ local ctx, err = rawget(config, "_context_factory")(config:final(), config);
+ if ctx then
+ rawset(config, "_cache", ctx);
+ end
+ return ctx, err
+end
+
local sslopts_mt = {
__index = {
apply = apply;
final = final;
+ build = build;
};
+ __newindex = function()
+ error("SSL config objects cannot be modified directly. Use :apply()")
+ end;
};
-local function new()
- return setmetatable({options={}}, sslopts_mt);
+
+-- passing basedir through everything is required to avoid sslconfig depending
+-- on prosody.paths.config
+local function new(context_factory, basedir)
+ return setmetatable({
+ _context_factory = context_factory,
+ _basedir = basedir,
+ options={},
+ }, sslopts_mt);
end
+local function clone(config)
+ local result = new();
+ for k, v in pairs(config) do
+ -- note that we *do* copy the internal keys on clone -- we have to carry
+ -- both the factory and the cache with us
+ rawset(result, k, v);
+ end
+ return result
+end
+
+sslopts_mt.__index.clone = clone;
+
return {
apply = apply;
final = final;
- new = new;
+ _new = new;
};
diff --git a/util/stanza.lua b/util/stanza.lua
index a38f80b3..a14be5f0 100644
--- a/util/stanza.lua
+++ b/util/stanza.lua
@@ -21,6 +21,8 @@ local type = type;
local s_gsub = string.gsub;
local s_sub = string.sub;
local s_find = string.find;
+local t_move = table.move or require "util.table".move;
+local t_create = require"util.table".create;
local valid_utf8 = require "util.encodings".utf8.valid;
@@ -275,25 +277,33 @@ function stanza_mt:find(path)
end
local function _clone(stanza, only_top)
- local attr, tags = {}, {};
+ local attr = {};
for k,v in pairs(stanza.attr) do attr[k] = v; end
local old_namespaces, namespaces = stanza.namespaces;
if old_namespaces then
namespaces = {};
for k,v in pairs(old_namespaces) do namespaces[k] = v; end
end
- local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ local tags, new;
+ if only_top then
+ tags = {};
+ new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags };
+ else
+ tags = t_create(#stanza.tags, 0);
+ new = t_create(#stanza, 4);
+ new.name = stanza.name;
+ new.attr = attr;
+ new.namespaces = namespaces;
+ new.tags = tags;
+ end
+
+ setmetatable(new, stanza_mt);
if not only_top then
- for i=1,#stanza do
- local child = stanza[i];
- if child.name then
- child = _clone(child);
- t_insert(tags, child);
- end
- t_insert(new, child);
- end
+ t_move(stanza, 1, #stanza, 1, new);
+ t_move(stanza.tags, 1, #stanza.tags, 1, tags);
+ new:maptags(_clone);
end
- return setmetatable(new, stanza_mt);
+ return new;
end
local function clone(stanza, only_top)
diff --git a/util/vcard.lua b/util/vcard.lua
deleted file mode 100644
index e311f73f..00000000
--- a/util/vcard.lua
+++ /dev/null
@@ -1,574 +0,0 @@
--- Copyright (C) 2011-2014 Kim Alvefur
---
--- This project is MIT/X11 licensed. Please see the
--- COPYING file in the source package for more information.
---
-
--- TODO
--- Fix folding.
-
-local st = require "util.stanza";
-local t_insert, t_concat = table.insert, table.concat;
-local type = type;
-local pairs, ipairs = pairs, ipairs;
-
-local from_text, to_text, from_xep54, to_xep54;
-
-local line_sep = "\n";
-
-local vCard_dtd; -- See end of file
-local vCard4_dtd;
-
-local function vCard_esc(s)
- return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n");
-end
-
-local function vCard_unesc(s)
- return s:gsub("\\?[\\nt:;,]", {
- ["\\\\"] = "\\",
- ["\\n"] = "\n",
- ["\\r"] = "\r",
- ["\\t"] = "\t",
- ["\\:"] = ":", -- FIXME Shouldn't need to escape : in values, just params
- ["\\;"] = ";",
- ["\\,"] = ",",
- [":"] = "\29",
- [";"] = "\30",
- [","] = "\31",
- });
-end
-
-local function item_to_xep54(item)
- local t = st.stanza(item.name, { xmlns = "vcard-temp" });
-
- local prop_def = vCard_dtd[item.name];
- if prop_def == "text" then
- t:text(item[1]);
- elseif type(prop_def) == "table" then
- if prop_def.types and item.TYPE then
- if type(item.TYPE) == "table" then
- for _,v in pairs(prop_def.types) do
- for _,typ in pairs(item.TYPE) do
- if typ:upper() == v then
- t:tag(v):up();
- break;
- end
- end
- end
- else
- t:tag(item.TYPE:upper()):up();
- end
- end
-
- if prop_def.props then
- for _,prop in pairs(prop_def.props) do
- if item[prop] then
- for _, v in ipairs(item[prop]) do
- t:text_tag(prop, v);
- end
- end
- end
- end
-
- if prop_def.value then
- t:text_tag(prop_def.value, item[1]);
- elseif prop_def.values then
- local prop_def_values = prop_def.values;
- local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values];
- for i=1,#item do
- t:text_tag(prop_def.values[i] or repeat_last, item[i]);
- end
- end
- end
-
- return t;
-end
-
-local function vcard_to_xep54(vCard)
- local t = st.stanza("vCard", { xmlns = "vcard-temp" });
- for i=1,#vCard do
- t:add_child(item_to_xep54(vCard[i]));
- end
- return t;
-end
-
-function to_xep54(vCards)
- if not vCards[1] or vCards[1].name then
- return vcard_to_xep54(vCards)
- else
- local t = st.stanza("xCard", { xmlns = "vcard-temp" });
- for i=1,#vCards do
- t:add_child(vcard_to_xep54(vCards[i]));
- end
- return t;
- end
-end
-
-function from_text(data)
- data = data -- unfold and remove empty lines
- :gsub("\r\n","\n")
- :gsub("\n ", "")
- :gsub("\n\n+","\n");
- local vCards = {};
- local current;
- for line in data:gmatch("[^\n]+") do
- line = vCard_unesc(line);
- local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$");
- value = value:gsub("\29",":");
- if #params > 0 then
- local _params = {};
- for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do
- k = k:upper();
- local _vt = {};
- for _p in v:gmatch("[^\31]+") do
- _vt[#_vt+1]=_p
- _vt[_p]=true;
- end
- if isval == "=" then
- _params[k]=_vt;
- else
- _params[k]=true;
- end
- end
- params = _params;
- end
- if name == "BEGIN" and value == "VCARD" then
- current = {};
- vCards[#vCards+1] = current;
- elseif name == "END" and value == "VCARD" then
- current = nil;
- elseif current and vCard_dtd[name] then
- local dtd = vCard_dtd[name];
- local item = { name = name };
- t_insert(current, item);
- local up = current;
- current = item;
- if dtd.types then
- for _, t in ipairs(dtd.types) do
- t = t:lower();
- if ( params.TYPE and params.TYPE[t] == true)
- or params[t] == true then
- current.TYPE=t;
- end
- end
- end
- if dtd.props then
- for _, p in ipairs(dtd.props) do
- if params[p] then
- if params[p] == true then
- current[p]=true;
- else
- for _, prop in ipairs(params[p]) do
- current[p]=prop;
- end
- end
- end
- end
- end
- if dtd == "text" or dtd.value then
- t_insert(current, value);
- elseif dtd.values then
- for p in ("\30"..value):gmatch("\30([^\30]*)") do
- t_insert(current, p);
- end
- end
- current = up;
- end
- end
- return vCards;
-end
-
-local function item_to_text(item)
- local value = {};
- for i=1,#item do
- value[i] = vCard_esc(item[i]);
- end
- value = t_concat(value, ";");
-
- local params = "";
- for k,v in pairs(item) do
- if type(k) == "string" and k ~= "name" then
- params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v);
- end
- end
-
- return ("%s%s:%s"):format(item.name, params, value)
-end
-
-local function vcard_to_text(vcard)
- local t={};
- t_insert(t, "BEGIN:VCARD")
- for i=1,#vcard do
- t_insert(t, item_to_text(vcard[i]));
- end
- t_insert(t, "END:VCARD")
- return t_concat(t, line_sep);
-end
-
-function to_text(vCards)
- if vCards[1] and vCards[1].name then
- return vcard_to_text(vCards)
- else
- local t = {};
- for i=1,#vCards do
- t[i]=vcard_to_text(vCards[i]);
- end
- return t_concat(t, line_sep);
- end
-end
-
-local function from_xep54_item(item)
- local prop_name = item.name;
- local prop_def = vCard_dtd[prop_name];
-
- local prop = { name = prop_name };
-
- if prop_def == "text" then
- prop[1] = item:get_text();
- elseif type(prop_def) == "table" then
- if prop_def.value then --single item
- prop[1] = item:get_child_text(prop_def.value) or "";
- elseif prop_def.values then --array
- local value_names = prop_def.values;
- if value_names.behaviour == "repeat-last" then
- for i=1,#item.tags do
- t_insert(prop, item.tags[i]:get_text() or "");
- end
- else
- for i=1,#value_names do
- t_insert(prop, item:get_child_text(value_names[i]) or "");
- end
- end
- elseif prop_def.names then
- local names = prop_def.names;
- for i=1,#names do
- if item:get_child(names[i]) then
- prop[1] = names[i];
- break;
- end
- end
- end
-
- if prop_def.props_verbatim then
- for k,v in pairs(prop_def.props_verbatim) do
- prop[k] = v;
- end
- end
-
- if prop_def.types then
- local types = prop_def.types;
- prop.TYPE = {};
- for i=1,#types do
- if item:get_child(types[i]) then
- t_insert(prop.TYPE, types[i]:lower());
- end
- end
- if #prop.TYPE == 0 then
- prop.TYPE = nil;
- end
- end
-
- -- A key-value pair, within a key-value pair?
- if prop_def.props then
- local params = prop_def.props;
- for i=1,#params do
- local name = params[i]
- local data = item:get_child_text(name);
- if data then
- prop[name] = prop[name] or {};
- t_insert(prop[name], data);
- end
- end
- end
- else
- return nil
- end
-
- return prop;
-end
-
-local function from_xep54_vCard(vCard)
- local tags = vCard.tags;
- local t = {};
- for i=1,#tags do
- t_insert(t, from_xep54_item(tags[i]));
- end
- return t
-end
-
-function from_xep54(vCard)
- if vCard.attr.xmlns ~= "vcard-temp" then
- return nil, "wrong-xmlns";
- end
- if vCard.name == "xCard" then -- A collection of vCards
- local t = {};
- local vCards = vCard.tags;
- for i=1,#vCards do
- t[i] = from_xep54_vCard(vCards[i]);
- end
- return t
- elseif vCard.name == "vCard" then -- A single vCard
- return from_xep54_vCard(vCard)
- end
-end
-
-local vcard4 = { }
-
-function vcard4:text(node, params, value) -- luacheck: ignore 212/params
- self:tag(node:lower())
- -- FIXME params
- if type(value) == "string" then
- self:text_tag("text", value);
- elseif vcard4[node] then
- vcard4[node](value);
- end
- self:up();
-end
-
-function vcard4.N(value)
- for i, k in ipairs(vCard_dtd.N.values) do
- value:text_tag(k, value[i]);
- end
-end
-
-local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0"
-
-local function item_to_vcard4(item)
- local typ = item.name:lower();
- local t = st.stanza(typ, { xmlns = xmlns_vcard4 });
-
- local prop_def = vCard4_dtd[typ];
- if prop_def == "text" then
- t:text_tag("text", item[1]);
- elseif prop_def == "uri" then
- if item.ENCODING and item.ENCODING[1] == 'b' then
- t:text_tag("uri", "data:;base64," .. item[1]);
- else
- t:text_tag("uri", item[1]);
- end
- elseif type(prop_def) == "table" then
- if prop_def.values then
- for i, v in ipairs(prop_def.values) do
- t:text_tag(v:lower(), item[i]);
- end
- else
- t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
- end
- else
- t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"})
- end
- return t;
-end
-
-local function vcard_to_vcard4xml(vCard)
- local t = st.stanza("vcard", { xmlns = xmlns_vcard4 });
- for i=1,#vCard do
- t:add_child(item_to_vcard4(vCard[i]));
- end
- return t;
-end
-
-local function vcards_to_vcard4xml(vCards)
- if not vCards[1] or vCards[1].name then
- return vcard_to_vcard4xml(vCards)
- else
- local t = st.stanza("vcards", { xmlns = xmlns_vcard4 });
- for i=1,#vCards do
- t:add_child(vcard_to_vcard4xml(vCards[i]));
- end
- return t;
- end
-end
-
--- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd
-vCard_dtd = {
- VERSION = "text", --MUST be 3.0, so parsing is redundant
- FN = "text",
- N = {
- values = {
- "FAMILY",
- "GIVEN",
- "MIDDLE",
- "PREFIX",
- "SUFFIX",
- },
- },
- NICKNAME = "text",
- PHOTO = {
- props_verbatim = { ENCODING = { "b" } },
- props = { "TYPE" },
- value = "BINVAL", --{ "EXTVAL", },
- },
- BDAY = "text",
- ADR = {
- types = {
- "HOME",
- "WORK",
- "POSTAL",
- "PARCEL",
- "DOM",
- "INTL",
- "PREF",
- },
- values = {
- "POBOX",
- "EXTADD",
- "STREET",
- "LOCALITY",
- "REGION",
- "PCODE",
- "CTRY",
- }
- },
- LABEL = {
- types = {
- "HOME",
- "WORK",
- "POSTAL",
- "PARCEL",
- "DOM",
- "INTL",
- "PREF",
- },
- value = "LINE",
- },
- TEL = {
- types = {
- "HOME",
- "WORK",
- "VOICE",
- "FAX",
- "PAGER",
- "MSG",
- "CELL",
- "VIDEO",
- "BBS",
- "MODEM",
- "ISDN",
- "PCS",
- "PREF",
- },
- value = "NUMBER",
- },
- EMAIL = {
- types = {
- "HOME",
- "WORK",
- "INTERNET",
- "PREF",
- "X400",
- },
- value = "USERID",
- },
- JABBERID = "text",
- MAILER = "text",
- TZ = "text",
- GEO = {
- values = {
- "LAT",
- "LON",
- },
- },
- TITLE = "text",
- ROLE = "text",
- LOGO = "copy of PHOTO",
- AGENT = "text",
- ORG = {
- values = {
- behaviour = "repeat-last",
- "ORGNAME",
- "ORGUNIT",
- }
- },
- CATEGORIES = {
- values = "KEYWORD",
- },
- NOTE = "text",
- PRODID = "text",
- REV = "text",
- SORTSTRING = "text",
- SOUND = "copy of PHOTO",
- UID = "text",
- URL = "text",
- CLASS = {
- names = { -- The item.name is the value if it's one of these.
- "PUBLIC",
- "PRIVATE",
- "CONFIDENTIAL",
- },
- },
- KEY = {
- props = { "TYPE" },
- value = "CRED",
- },
- DESC = "text",
-};
-vCard_dtd.LOGO = vCard_dtd.PHOTO;
-vCard_dtd.SOUND = vCard_dtd.PHOTO;
-
-vCard4_dtd = {
- source = "uri",
- kind = "text",
- xml = "text",
- fn = "text",
- n = {
- values = {
- "family",
- "given",
- "middle",
- "prefix",
- "suffix",
- },
- },
- nickname = "text",
- photo = "uri",
- bday = "date-and-or-time",
- anniversary = "date-and-or-time",
- gender = "text",
- adr = {
- values = {
- "pobox",
- "ext",
- "street",
- "locality",
- "region",
- "code",
- "country",
- }
- },
- tel = "text",
- email = "text",
- impp = "uri",
- lang = "language-tag",
- tz = "text",
- geo = "uri",
- title = "text",
- role = "text",
- logo = "uri",
- org = "text",
- member = "uri",
- related = "uri",
- categories = "text",
- note = "text",
- prodid = "text",
- rev = "timestamp",
- sound = "uri",
- uid = "uri",
- clientpidmap = "number, uuid",
- url = "uri",
- version = "text",
- key = "uri",
- fburl = "uri",
- caladruri = "uri",
- caluri = "uri",
-};
-
-return {
- from_text = from_text;
- to_text = to_text;
-
- from_xep54 = from_xep54;
- to_xep54 = to_xep54;
-
- to_vcard4 = vcards_to_vcard4xml;
-};
diff --git a/util/watchdog.lua b/util/watchdog.lua
index 516e60e4..407028a5 100644
--- a/util/watchdog.lua
+++ b/util/watchdog.lua
@@ -1,6 +1,5 @@
local timer = require "util.timer";
local setmetatable = setmetatable;
-local os_time = os.time;
local _ENV = nil;
-- luacheck: std none
@@ -9,27 +8,35 @@ local watchdog_methods = {};
local watchdog_mt = { __index = watchdog_methods };
local function new(timeout, callback)
- local watchdog = setmetatable({ timeout = timeout, last_reset = os_time(), callback = callback }, watchdog_mt);
- timer.add_task(timeout+1, function (current_time)
- local last_reset = watchdog.last_reset;
- if not last_reset then
- return;
- end
- local time_left = (last_reset + timeout) - current_time;
- if time_left < 0 then
- return watchdog:callback();
- end
- return time_left + 1;
- end);
+ local watchdog = setmetatable({
+ timeout = timeout;
+ callback = callback;
+ timer_id = nil;
+ }, watchdog_mt);
+
+ watchdog:reset(); -- Kick things off
+
return watchdog;
end
-function watchdog_methods:reset()
- self.last_reset = os_time();
+function watchdog_methods:reset(new_timeout)
+ if new_timeout then
+ self.timeout = new_timeout;
+ end
+ if self.timer_id then
+ timer.reschedule(self.timer_id, self.timeout+1);
+ else
+ self.timer_id = timer.add_task(self.timeout+1, function ()
+ return self:callback();
+ end);
+ end
end
function watchdog_methods:cancel()
- self.last_reset = nil;
+ if self.timer_id then
+ timer.stop(self.timer_id);
+ self.timer_id = nil;
+ end
end
return {