aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/certmanager.lua136
-rw-r--r--core/configmanager.lua48
-rw-r--r--core/hostmanager.lua1
-rw-r--r--core/loggingmanager.lua38
-rw-r--r--core/moduleapi.lua167
-rw-r--r--core/modulemanager.lua69
-rw-r--r--core/portmanager.lua104
-rw-r--r--core/rostermanager.lua29
-rw-r--r--core/s2smanager.lua51
-rw-r--r--core/sessionmanager.lua42
-rw-r--r--core/stanza_router.lua28
-rw-r--r--core/statsmanager.lua226
-rw-r--r--core/storagemanager.lua33
-rw-r--r--core/usermanager.lua112
14 files changed, 911 insertions, 173 deletions
diff --git a/core/certmanager.lua b/core/certmanager.lua
index d8d07636..7c7fc150 100644
--- a/core/certmanager.lua
+++ b/core/certmanager.lua
@@ -20,20 +20,26 @@ end
local configmanager = require "core.configmanager";
local log = require "util.logger".init("certmanager");
local ssl_context = ssl.context or softreq"ssl.context";
-local ssl_x509 = ssl.x509 or softreq"ssl.x509";
local ssl_newcontext = ssl.newcontext;
local new_config = require"util.sslconfig".new;
local stat = require "lfs".attributes;
+local x509 = require "util.x509";
+local lfs = require "lfs";
+
local tonumber, tostring = tonumber, tostring;
local pairs = pairs;
local t_remove = table.remove;
local type = type;
local io_open = io.open;
local select = select;
+local now = os.time;
+local next = next;
+local pcall = pcall;
local prosody = prosody;
-local resolve_path = require"util.paths".resolve_relative_path;
+local pathutil = require"util.paths";
+local resolve_path = pathutil.resolve_relative_path;
local config_path = prosody.paths.config or ".";
local function test_option(option)
@@ -81,7 +87,7 @@ local function find_cert(user_certs, name)
if crt_path == key_path then
if key_path:sub(-4) == ".crt" then
key_path = key_path:sub(1, -4) .. "key";
- elseif key_path:sub(-13) == "fullchain.pem" then
+ elseif key_path:sub(-14) == "/fullchain.pem" then
key_path = key_path:sub(1, -14) .. "privkey.pem";
end
end
@@ -95,12 +101,107 @@ local function find_cert(user_certs, name)
log("debug", "No certificate/key found for %s", name);
end
+local function find_matching_key(cert_path)
+ -- FIXME we shouldn't need to guess the key filename
+ if cert_path:sub(-4) == ".crt" then
+ return cert_path:sub(1, -4) .. "key";
+ elseif cert_path:sub(-14) == "/fullchain.pem" then
+ return cert_path:sub(1, -14) .. "privkey.pem";
+ end
+end
+
+local function index_certs(dir, files_by_name, depth_limit)
+ files_by_name = files_by_name or {};
+ depth_limit = depth_limit or 3;
+ if depth_limit <= 0 then return files_by_name; end
+
+ local ok, iter, v, i = pcall(lfs.dir, dir);
+ if not ok then
+ log("error", "Error indexing certificate directory %s: %s", dir, iter);
+ -- Return an empty index, otherwise this just triggers a nil indexing
+ -- error, plus this function would get called again.
+ -- Reloading the config after correcting the problem calls this again so
+ -- that's what should be done.
+ return {}, iter;
+ end
+ for file in iter, v, i do
+ local full = pathutil.join(dir, file);
+ if lfs.attributes(full, "mode") == "directory" then
+ if file:sub(1,1) ~= "." then
+ index_certs(full, files_by_name, depth_limit-1);
+ end
+ -- TODO support more filename patterns?
+ elseif full:match("%.crt$") or full:match("/fullchain%.pem$") then
+ local f = io_open(full);
+ if f then
+ -- TODO look for chained certificates
+ local firstline = f:read();
+ if firstline == "-----BEGIN CERTIFICATE-----" then
+ f:seek("set")
+ local cert = ssl.loadcertificate(f:read("*a"))
+ -- TODO if more than one cert is found for a name, the most recently
+ -- issued one should be used.
+ -- for now, just filter out expired certs
+ -- TODO also check if there's a corresponding key
+ if cert:validat(now()) then
+ local names = x509.get_identities(cert);
+ log("debug", "Found certificate %s with identities %q", full, names);
+ for name, services in pairs(names) do
+ -- TODO check services
+ if files_by_name[name] then
+ files_by_name[name][full] = services;
+ else
+ files_by_name[name] = { [full] = services; };
+ end
+ end
+ end
+ end
+ f:close();
+ end
+ end
+ end
+ log("debug", "Certificate index: %q", files_by_name);
+ -- | hostname | filename | service |
+ return files_by_name;
+end
+
+local cert_index;
+
local function find_host_cert(host)
if not host then return nil; end
+ if not cert_index then
+ cert_index = index_certs(resolve_path(config_path, global_certificates));
+ end
+ local certs = cert_index[host];
+ if certs then
+ local cert_filename, services = next(certs);
+ if services["*"] then
+ log("debug", "Using cert %q from index", cert_filename);
+ return {
+ certificate = cert_filename,
+ key = find_matching_key(cert_filename),
+ }
+ end
+ end
+
return find_cert(configmanager.get(host, "certificate"), host) or find_host_cert(host:match("%.(.+)$"));
end
local function find_service_cert(service, port)
+ if not cert_index then
+ cert_index = index_certs(resolve_path(config_path, global_certificates));
+ end
+ for _, certs in pairs(cert_index) do
+ for cert_filename, services in pairs(certs) do
+ if services[service] or services["*"] then
+ log("debug", "Using cert %q from index", cert_filename);
+ return {
+ certificate = cert_filename,
+ key = find_matching_key(cert_filename),
+ }
+ end
+ end
+ end
local cert_config = configmanager.get("*", service.."_certificate");
if type(cert_config) == "table" then
cert_config = cert_config[port] or cert_config.default;
@@ -113,7 +214,7 @@ local core_defaults = {
capath = "/etc/ssl/certs";
depth = 9;
protocol = "tlsv1+";
- verify = (ssl_x509 and { "peer", "client_once", }) or "none";
+ verify = "none";
options = {
cipher_server_preference = luasec_has.options.cipher_server_preference;
no_ticket = luasec_has.options.no_ticket;
@@ -122,7 +223,10 @@ local core_defaults = {
single_ecdh_use = luasec_has.options.single_ecdh_use;
no_renegotiation = luasec_has.options.no_renegotiation;
};
- verifyext = { "lsec_continue", "lsec_ignore_purpose" };
+ verifyext = {
+ "lsec_continue", -- Continue past certificate verification errors
+ "lsec_ignore_purpose", -- Validate client certificates as if they were server certificates
+ };
curve = luasec_has.algorithms.ec and not luasec_has.capabilities.curves_list and "secp384r1";
curveslist = {
"X25519",
@@ -140,6 +244,7 @@ local core_defaults = {
"!3DES", -- 3DES - slow and of questionable security
"!aNULL", -- Ciphers that does not authenticate the connection
};
+ dane = configmanager.get("*", "use_dane");
}
if luasec_has.curves then
@@ -156,20 +261,16 @@ local path_options = { -- These we pass through resolve_path()
key = true, certificate = true, cafile = true, capath = true, dhparam = true
}
-if luasec_version < 5 and ssl_x509 then
- -- COMPAT mw/luasec-hg
- for i=1,#core_defaults.verifyext do -- Remove lsec_ prefix
- core_defaults.verify[#core_defaults.verify+1] = core_defaults.verifyext[i]:sub(6);
- end
-end
-
local function create_context(host, mode, ...)
local cfg = new_config();
cfg:apply(core_defaults);
local service_name, port = host:match("^(%S+) port (%d+)$");
- if service_name then
+ -- port 0 is used with client-only things that normally don't need certificates, e.g. https
+ if service_name and port ~= "0" then
+ log("debug", "Automatically locating certs for service %s on port %s", service_name, port);
cfg:apply(find_service_cert(service_name, tonumber(port)));
else
+ log("debug", "Automatically locating certs for host %s", host);
cfg:apply(find_host_cert(host));
end
cfg:apply({
@@ -185,8 +286,10 @@ local function create_context(host, mode, ...)
local user_ssl_config = cfg:final();
if mode == "server" then
- if not user_ssl_config.certificate then return nil, "No certificate present in SSL/TLS configuration for "..host; end
- if not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end
+ if not user_ssl_config.certificate then
+ log("info", "No certificate present in SSL/TLS configuration for %s. SNI will be required.", host);
+ end
+ if user_ssl_config.certificate and not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end
end
for option in pairs(path_options) do
@@ -258,6 +361,8 @@ local function reload_ssl_config()
if luasec_has.options.no_compression then
core_defaults.options.no_compression = configmanager.get("*", "ssl_compression") ~= true;
end
+ core_defaults.dane = configmanager.get("*", "use_dane") or false;
+ cert_index = index_certs(resolve_path(config_path, global_certificates));
end
prosody.events.add_handler("config-reloaded", reload_ssl_config);
@@ -266,4 +371,5 @@ return {
create_context = create_context;
reload_ssl_config = reload_ssl_config;
find_cert = find_cert;
+ find_host_cert = find_host_cert;
};
diff --git a/core/configmanager.lua b/core/configmanager.lua
index 1e67da9b..ae0a274a 100644
--- a/core/configmanager.lua
+++ b/core/configmanager.lua
@@ -7,15 +7,16 @@
--
local _G = _G;
-local setmetatable, rawget, rawset, io, os, error, dofile, type, pairs =
- setmetatable, rawget, rawset, io, os, error, dofile, type, pairs;
-local format, math_max = string.format, math.max;
+local setmetatable, rawget, rawset, io, os, error, dofile, type, pairs, ipairs =
+ setmetatable, rawget, rawset, io, os, error, dofile, type, pairs, ipairs;
+local format, math_max, t_insert = string.format, math.max, table.insert;
local envload = require"util.envload".envload;
local deps = require"util.dependencies";
local resolve_relative_path = require"util.paths".resolve_relative_path;
local glob_to_pattern = require"util.paths".glob_to_pattern;
local path_sep = package.config:sub(1,1);
+local get_traceback_table = require "util.debug".get_traceback_table;
local encodings = deps.softreq"util.encodings";
local nameprep = encodings and encodings.stringprep.nameprep or function (host) return host:lower(); end
@@ -100,8 +101,18 @@ end
-- Built-in Lua parser
do
local pcall = _G.pcall;
+ local function get_line_number(config_file)
+ local tb = get_traceback_table(nil, 2);
+ for i = 1, #tb do
+ if tb[i].info.short_src == config_file then
+ return tb[i].info.currentline;
+ end
+ end
+ end
parser = {};
function parser.load(data, config_file, config_table)
+ local set_options = {}; -- set_options[host.."/"..option_name] = true (when the option has been set already in this file)
+ local warnings = {};
local env;
-- The ' = true' are needed so as not to set off __newindex when we assign the functions below
env = setmetatable({
@@ -115,13 +126,26 @@ do
return rawget(_G, k);
end,
__newindex = function (_, k, v)
+ local host = env.__currenthost or "*";
+ local option_path = host.."/"..k;
+ if set_options[option_path] then
+ t_insert(warnings, ("%s:%d: Duplicate option '%s'"):format(config_file, get_line_number(config_file), k));
+ end
+ set_options[option_path] = true;
set(config_table, env.__currenthost or "*", k, v);
end
});
rawset(env, "__currenthost", "*") -- Default is global
function env.VirtualHost(name)
- name = nameprep(name);
+ if not name then
+ error("Host must have a name", 2);
+ end
+ local prepped_name = nameprep(name);
+ if not prepped_name then
+ error(format("Name of Host %q contains forbidden characters", name), 0);
+ end
+ name = prepped_name;
if rawget(config_table, name) and rawget(config_table[name], "component_module") then
error(format("Host %q clashes with previously defined %s Component %q, for services use a sub-domain like conference.%s",
name, config_table[name].component_module:gsub("^%a+$", { component = "external", muc = "MUC"}), name, name), 0);
@@ -139,7 +163,14 @@ do
env.Host, env.host = env.VirtualHost, env.VirtualHost;
function env.Component(name)
- name = nameprep(name);
+ if not name then
+ error("Component must have a name", 2);
+ end
+ local prepped_name = nameprep(name);
+ if not prepped_name then
+ error(format("Name of Component %q contains forbidden characters", name), 0);
+ end
+ name = prepped_name;
if rawget(config_table, name) and rawget(config_table[name], "defined")
and not rawget(config_table[name], "component_module") then
error(format("Component %q clashes with previously defined Host %q, for services use a sub-domain like conference.%s",
@@ -195,6 +226,11 @@ do
if f then
local ret, err = parser.load(f:read("*a"), file, config_table);
if not ret then error(err:gsub("%[string.-%]", file), 0); end
+ if err then
+ for _, warning in ipairs(err) do
+ t_insert(warnings, warning);
+ end
+ end
end
if not f then error("Error loading included "..file..": "..err, 0); end
return f, err;
@@ -217,7 +253,7 @@ do
return nil, err;
end
- return true;
+ return true, warnings;
end
end
diff --git a/core/hostmanager.lua b/core/hostmanager.lua
index 9acca517..f33a3e1e 100644
--- a/core/hostmanager.lua
+++ b/core/hostmanager.lua
@@ -133,7 +133,6 @@ function deactivate(host, reason)
for remotehost, session in pairs(host_session.s2sout) do
if session.close then
log("debug", "Closing outgoing connection to %s", remotehost);
- if session.srv_hosts then session.srv_hosts = nil; end
session:close(reason);
end
end
diff --git a/core/loggingmanager.lua b/core/loggingmanager.lua
index cfa8246a..d8e557f9 100644
--- a/core/loggingmanager.lua
+++ b/core/loggingmanager.lua
@@ -14,10 +14,14 @@ local io_open = io.open;
local math_max, rep = math.max, string.rep;
local os_date = os.date;
local getstyle, getstring = require "util.termcolours".getstyle, require "util.termcolours".getstring;
+local st = require "util.stanza";
local config = require "core.configmanager";
local logger = require "util.logger";
+local have_pposix, pposix = pcall(require, "util.pposix");
+have_pposix = have_pposix and pposix._VERSION == "0.4.0";
+
local _ENV = nil;
-- luacheck: std none
@@ -33,6 +37,8 @@ local log_sink_types = setmetatable({}, { __newindex = function (t, k, v) rawset
local get_levels;
local logging_levels = { "debug", "info", "warn", "error" }
+local function id(x) return x end
+
-- Put a rule into action. Requires that the sink type has already been registered.
-- This function is called automatically when a new sink type is added [see apply_sink_rules()]
local function add_rule(sink_config)
@@ -181,15 +187,16 @@ local function log_to_file(sink_config, logfile)
-- Column width for "source" (used by stdout and console)
local sourcewidth = sink_config.source_width;
+ local filter = sink_config.filter or id;
if sourcewidth then
return function (name, level, message, ...)
sourcewidth = math_max(#name+2, sourcewidth);
- write(logfile, timestamps and os_date(timestamps) or "", name, rep(" ", sourcewidth-#name), level, "\t", format(message, ...), "\n");
+ write(logfile, timestamps and os_date(timestamps) or "", name, rep(" ", sourcewidth-#name), level, "\t", filter(format(message, ...)), "\n");
end
else
return function (name, level, message, ...)
- write(logfile, timestamps and os_date(timestamps) or "", name, "\t", level, "\t", format(message, ...), "\n");
+ write(logfile, timestamps and os_date(timestamps) or "", name, "\t", level, "\t", filter(format(message, ...)), "\n");
end
end
end
@@ -206,22 +213,25 @@ local function log_to_stdout(sink_config)
end
log_sink_types.stdout = log_to_stdout;
-local do_pretty_printing = true;
+local do_pretty_printing = not have_pposix or pposix.isatty(stdout);
-local logstyles;
+local logstyles, pretty;
if do_pretty_printing then
logstyles = {};
logstyles["info"] = getstyle("bold");
logstyles["warn"] = getstyle("bold", "yellow");
logstyles["error"] = getstyle("bold", "red");
+
+ pretty = st.pretty_print;
end
local function log_to_console(sink_config)
-- Really if we don't want pretty colours then just use plain stdout
- local logstdout = log_to_stdout(sink_config);
if not do_pretty_printing then
- return logstdout;
+ return log_to_stdout(sink_config);
end
+ sink_config.filter = pretty;
+ local logstdout = log_to_stdout(sink_config);
return function (name, level, message, ...)
local logstyle = logstyles[level];
if logstyle then
@@ -232,6 +242,22 @@ local function log_to_console(sink_config)
end
log_sink_types.console = log_to_console;
+if have_pposix then
+ local syslog_opened;
+ local function log_to_syslog(sink_config) -- luacheck: ignore 212/sink_config
+ if not syslog_opened then
+ local facility = sink_config.syslog_facility or config.get("*", "syslog_facility");
+ pposix.syslog_open(sink_config.syslog_name or "prosody", facility);
+ syslog_opened = true;
+ end
+ local syslog = pposix.syslog_log;
+ return function (name, level, message, ...)
+ syslog(level, name, format(message, ...));
+ end;
+ end
+ log_sink_types.syslog = log_to_syslog;
+end
+
local function register_sink_type(name, sink_maker)
local old_sink_maker = log_sink_types[name];
log_sink_types[name] = sink_maker;
diff --git a/core/moduleapi.lua b/core/moduleapi.lua
index 10f9f04d..59417027 100644
--- a/core/moduleapi.lua
+++ b/core/moduleapi.lua
@@ -14,13 +14,19 @@ local pluginloader = require "util.pluginloader";
local timer = require "util.timer";
local resolve_relative_path = require"util.paths".resolve_relative_path;
local st = require "util.stanza";
+local cache = require "util.cache";
+local errors = require "util.error";
+local promise = require "util.promise";
+local time_now = require "util.time".now;
+local format = require "util.format".format;
+local jid_node = require "util.jid".node;
local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
local error, setmetatable, type = error, setmetatable, type;
local ipairs, pairs, select = ipairs, pairs, select;
local tonumber, tostring = tonumber, tostring;
local require = require;
-local pack = table.pack or function(...) return {n=select("#",...), ...}; end -- table.pack is only in 5.2
+local pack = table.pack or require "util.table".pack; -- table.pack is only in 5.2
local unpack = table.unpack or unpack; --luacheck: ignore 113 -- renamed in 5.2
local prosody = prosody;
@@ -361,6 +367,100 @@ function api:send(stanza, origin)
return core_post_stanza(origin or hosts[self.host], stanza);
end
+function api:send_iq(stanza, origin, timeout)
+ local iq_cache = self._iq_cache;
+ if not iq_cache then
+ iq_cache = cache.new(256, function (_, iq)
+ iq.reject(errors.new({
+ type = "wait", condition = "resource-constraint",
+ text = "evicted from iq tracking cache"
+ }));
+ end);
+ self._iq_cache = iq_cache;
+ end
+
+ local event_type;
+ if not jid_node(stanza.attr.from) then
+ event_type = "host";
+ else -- assume bare since we can't hook full jids
+ event_type = "bare";
+ end
+ local result_event = "iq-result/"..event_type.."/"..stanza.attr.id;
+ local error_event = "iq-error/"..event_type.."/"..stanza.attr.id;
+ local cache_key = event_type.."/"..stanza.attr.id;
+
+ local p = promise.new(function (resolve, reject)
+ local function result_handler(event)
+ if event.stanza.attr.from == stanza.attr.to then
+ resolve(event);
+ return true;
+ end
+ end
+
+ local function error_handler(event)
+ if event.stanza.attr.from == stanza.attr.to then
+ reject(errors.from_stanza(event.stanza, event));
+ return true;
+ end
+ end
+
+ if iq_cache:get(cache_key) then
+ reject(errors.new({
+ type = "modify", condition = "conflict",
+ text = "IQ stanza id attribute already used",
+ }));
+ return;
+ end
+
+ self:hook(result_event, result_handler);
+ self:hook(error_event, error_handler);
+
+ local timeout_handle = self:add_timer(timeout or 120, function ()
+ reject(errors.new({
+ type = "wait", condition = "remote-server-timeout",
+ text = "IQ stanza timed out",
+ }));
+ end);
+
+ local ok = iq_cache:set(cache_key, {
+ reject = reject, resolve = resolve,
+ timeout_handle = timeout_handle,
+ result_handler = result_handler, error_handler = error_handler;
+ });
+
+ if not ok then
+ reject(errors.new({
+ type = "wait", condition = "internal-server-error",
+ text = "Could not store IQ tracking data"
+ }));
+ return;
+ end
+
+ local wrapped_origin = setmetatable({
+ -- XXX Needed in some cases for replies to work correctly when sending queries internally.
+ send = function (reply)
+ resolve({ stanza = reply });
+ end;
+ }, {
+ __index = origin or hosts[self.host];
+ });
+
+ self:send(stanza, wrapped_origin);
+ end);
+
+ p:finally(function ()
+ local iq = iq_cache:get(cache_key);
+ if iq then
+ self:unhook(result_event, iq.result_handler);
+ self:unhook(error_event, iq.error_handler);
+ iq.timeout_handle:stop();
+ iq_cache:set(cache_key, nil);
+ end
+ end);
+
+ return p;
+end
+
function api:broadcast(jids, stanza, iter)
for jid in (iter or it.values)(jids) do
local new_stanza = st.clone(stanza);
@@ -396,7 +496,7 @@ end
local path_sep = package.config:sub(1,1);
function api:get_directory()
- return self.path and (self.path:gsub("%"..path_sep.."[^"..path_sep.."]*$", "")) or nil;
+ return self.resource_path or self.path and (self.path:gsub("%"..path_sep.."[^"..path_sep.."]*$", "")) or nil;
end
function api:load_resource(path, mode)
@@ -408,28 +508,63 @@ function api:open_store(name, store_type)
return require"core.storagemanager".open(self.host, name or self.name, store_type);
end
-function api:measure(name, stat_type)
+function api:measure(name, stat_type, conf)
local measure = require "core.statsmanager".measure;
- return measure(stat_type, "/"..self.host.."/mod_"..self.name.."/"..name);
+ local fixed_label_key, fixed_label_value
+ if self.host ~= "*" then
+ fixed_label_key = "host"
+ fixed_label_value = self.host
+ end
+ -- new_legacy_metric takes care of scoping for us, as it does not accept
+ -- an array of labels
+ -- the prosody_ prefix is automatically added by statsmanager for legacy
+ -- metrics.
+ return measure(stat_type, "mod_"..self.name.."/"..name, conf, fixed_label_key, fixed_label_value)
+end
+
+function api:metric(type_, name, unit, description, label_keys, conf)
+ local metric = require "core.statsmanager".metric;
+ local is_scoped = self.host ~= "*"
+ if is_scoped then
+ -- prepend `host` label to label keys if this is not a global module
+ local orig_labels = label_keys
+ label_keys = array { "host" }
+ label_keys:append(orig_labels)
+ end
+ local mf = metric(type_, "prosody_mod_"..self.name.."/"..name, unit, description, label_keys, conf)
+ if is_scoped then
+ -- make sure to scope the returned metric family to the current host
+ return mf:with_partial_label(self.host)
+ end
+ return mf
end
-function api:measure_object_event(events_object, event_name, stat_name)
- local m = self:measure(stat_name or event_name, "times");
- local function handler(handlers, _event_name, _event_data)
- local finished = m();
- local ret = handlers(_event_name, _event_data);
- finished();
- return ret;
+local status_priorities = { error = 3, warn = 2, info = 1, core = 0 };
+
+function api:set_status(status_type, status_message, override)
+ local priority = status_priorities[status_type];
+ if not priority then
+ self:log("error", "set_status: Invalid status type '%s', assuming 'info'");
+ status_type, priority = "info", status_priorities.info;
+ end
+ local current_priority = status_priorities[self.status_type] or 0;
+ -- By default an 'error' status can only be overwritten by another 'error' status
+ if (current_priority >= status_priorities.error and priority < current_priority and override ~= true)
+ or (override == false and current_priority > priority) then
+ self:log("debug", "moduleapi: ignoring status [prio %d override %s]: %s", priority, override, status_message);
+ return;
end
- return self:hook_object_event(events_object, event_name, handler);
+ self.status_type, self.status_message, self.status_time = status_type, status_message, time_now();
+ self:fire_event("module-status/updated", { name = self.name });
end
-function api:measure_event(event_name, stat_name)
- return self:measure_object_event((hosts[self.host] or prosody).events.wrappers, event_name, stat_name);
+function api:log_status(level, msg, ...)
+ self:set_status(level, format(msg, ...));
+ return self:log(level, msg, ...);
end
-function api:measure_global_event(event_name, stat_name)
- return self:measure_object_event(prosody.events.wrappers, event_name, stat_name);
+function api:get_status()
+ return self.status_type, self.status_message, self.status_time;
end
return api;
diff --git a/core/modulemanager.lua b/core/modulemanager.lua
index a824d36a..9256ae4f 100644
--- a/core/modulemanager.lua
+++ b/core/modulemanager.lua
@@ -10,6 +10,7 @@ local logger = require "util.logger";
local log = logger.init("modulemanager");
local config = require "core.configmanager";
local pluginloader = require "util.pluginloader";
+local envload = require "util.envload";
local set = require "util.set";
local new_multitable = require "util.multitable".new;
@@ -22,9 +23,27 @@ local xpcall = require "util.xpcall".xpcall;
local debug_traceback = debug.traceback;
local setmetatable, rawget = setmetatable, rawget;
local ipairs, pairs, type, t_insert = ipairs, pairs, type, table.insert;
-
-local autoload_modules = {prosody.platform, "presence", "message", "iq", "offline", "c2s", "s2s", "s2s_auth_certs"};
-local component_inheritable_modules = {"tls", "saslauth", "dialback", "iq", "s2s"};
+local lua_version = _VERSION:match("5%.%d$");
+
+local autoload_modules = {
+ prosody.platform,
+ "presence",
+ "message",
+ "iq",
+ "offline",
+ "c2s",
+ "s2s",
+ "s2s_auth_certs",
+};
+local component_inheritable_modules = {
+ "tls",
+ "saslauth",
+ "dialback",
+ "iq",
+ "s2s",
+ "s2s_bidi",
+ "server_contact_info",
+};
-- We need this to let modules access the real global namespace
local _G = _G;
@@ -174,11 +193,48 @@ local function do_load_module(host, module_name, state)
local mod, err = pluginloader.load_code(module_name, nil, pluginenv);
if not mod then
log("error", "Unable to load module '%s': %s", module_name or "nil", err or "nil");
+ api_instance:set_status("error", "Failed to load (see log)");
return nil, err;
end
api_instance.path = err;
+ local custom_plugins = prosody.paths.installer;
+ if custom_plugins and err:sub(1, #custom_plugins+1) == custom_plugins.."/" then
+ -- Stage 1: Make it work (you are here)
+ -- Stage 2: Make it less hacky (TODO)
+ local manifest = {};
+ local luarocks_path = custom_plugins.."/lib/luarocks/rocks-"..lua_version;
+ local manifest_filename = luarocks_path.."/manifest";
+ local load_manifest, err = envload.envloadfile(manifest_filename, manifest);
+ if not load_manifest then
+ -- COMPAT Luarocks 2.x
+ log("debug", "Could not load LuaRocks 3.x manifest, trying 2.x", err);
+ luarocks_path = custom_plugins.."/lib/luarocks/rocks";
+ manifest_filename = luarocks_path.."/manifest";
+ load_manifest, err = envload.envloadfile(manifest_filename, manifest);
+ end
+ if not load_manifest then
+ log("error", "Could not load manifest of installed plugins: %s", err, load_manifest);
+ else
+ local ok, err = xpcall(load_manifest, debug_traceback);
+ if not ok then
+ log("error", "Could not load manifest of installed plugins: %s", err);
+ elseif type(manifest.modules) ~= "table" then
+ log("debug", "Expected 'table' but manifest.modules = %q", manifest.modules);
+ log("error", "Can't look up resource path for mod_%s because '%s' does not appear to be a LuaRocks manifest", module_name, manifest_filename);
+ else
+ local versions = manifest.modules["mod_"..module_name];
+ if type(versions) == "table" and versions[1] then
+ -- Not going to deal with multiple installed versions
+ api_instance.resource_path = luarocks_path.."/"..versions[1];
+ else
+ log("debug", "mod_%s does not appear in the installation manifest", module_name);
+ end
+ end
+ end
+ end
+
modulemap[host][module_name] = pluginenv;
local ok, err = xpcall(mod, debug_traceback);
if ok then
@@ -187,6 +243,7 @@ local function do_load_module(host, module_name, state)
ok, err = call_module_method(pluginenv, "load");
if not ok then
log("warn", "Error loading module '%s' on '%s': %s", module_name, host, err or "nil");
+ api_instance:set_status("warn", "Error during load (see log)");
end
end
api_instance.reloading, api_instance.saved_state = nil, nil;
@@ -209,6 +266,9 @@ local function do_load_module(host, module_name, state)
if not ok then
modulemap[api_instance.host][module_name] = nil;
log("error", "Error initializing module '%s' on '%s': %s", module_name, host, err or "nil");
+ api_instance:set_status("warn", "Error during load (see log)");
+ else
+ api_instance:set_status("core", "Loaded", false);
end
return ok and pluginenv, err;
end
@@ -225,7 +285,8 @@ local function do_reload_module(host, name)
local saved;
if module_has_method(mod, "save") then
- local ok, ret, err = call_module_method(mod, "save");
+ -- FIXME What goes in 'err' here?
+ local ok, ret, err = call_module_method(mod, "save"); -- luacheck: ignore 211/err
if ok then
saved = ret;
else
diff --git a/core/portmanager.lua b/core/portmanager.lua
index bed5eca5..95d42b77 100644
--- a/core/portmanager.lua
+++ b/core/portmanager.lua
@@ -9,7 +9,8 @@ local set = require "util.set";
local table = table;
local setmetatable, rawset, rawget = setmetatable, rawset, rawget;
-local type, tonumber, tostring, ipairs = type, tonumber, tostring, ipairs;
+local type, tonumber, ipairs = type, tonumber, ipairs;
+local pairs = pairs;
local prosody = prosody;
local fire_event = prosody.events.fire_event;
@@ -64,6 +65,20 @@ local function error_to_friendly_message(service_name, port, err) --luacheck: ig
return friendly_message;
end
+local function get_port_ssl_ctx(port, interface, config_prefix, service_info)
+ local global_ssl_config = config.get("*", "ssl") or {};
+ local prefix_ssl_config = config.get("*", config_prefix.."ssl") or global_ssl_config;
+ log("debug", "Creating context for direct TLS service %s on port %d", service_info.name, port);
+ local ssl, err, cfg = certmanager.create_context(service_info.name.." port "..port, "server",
+ prefix_ssl_config[interface],
+ prefix_ssl_config[port],
+ prefix_ssl_config,
+ service_info.ssl_config or {},
+ global_ssl_config[interface],
+ global_ssl_config[port]);
+ return ssl, cfg, err;
+end
+
--- Public API
local function activate(service_name)
@@ -95,31 +110,22 @@ local function activate(service_name)
}
bind_ports = set.new(type(bind_ports) ~= "table" and { bind_ports } or bind_ports );
- local mode, ssl = listener.default_mode or default_mode;
+ local mode = listener.default_mode or default_mode;
local hooked_ports = {};
for interface in bind_interfaces do
for port in bind_ports do
local port_number = tonumber(port);
if not port_number then
- log("error", "Invalid port number specified for service '%s': %s", service_info.name, tostring(port));
+ log("error", "Invalid port number specified for service '%s': %s", service_info.name, port);
elseif #active_services:search(nil, interface, port_number) > 0 then
log("error", "Multiple services configured to listen on the same port ([%s]:%d): %s, %s", interface, port,
active_services:search(nil, interface, port)[1][1].service.name or "<unnamed>", service_name or "<unnamed>");
else
- local err;
+ local ssl, cfg, err;
-- Create SSL context for this service/port
if service_info.encryption == "ssl" then
- local global_ssl_config = config.get("*", "ssl") or {};
- local prefix_ssl_config = config.get("*", config_prefix.."ssl") or global_ssl_config;
- log("debug", "Creating context for direct TLS service %s on port %d", service_info.name, port);
- ssl, err = certmanager.create_context(service_info.name.." port "..port, "server",
- prefix_ssl_config[interface],
- prefix_ssl_config[port],
- prefix_ssl_config,
- service_info.ssl_config or {},
- global_ssl_config[interface],
- global_ssl_config[port]);
+ ssl, cfg, err = get_port_ssl_ctx(port, interface, config_prefix, service_info);
if not ssl then
log("error", "Error binding encrypted port for %s: %s", service_info.name,
error_to_friendly_message(service_name, port_number, err) or "unknown error");
@@ -127,7 +133,12 @@ local function activate(service_name)
end
if not err then
-- Start listening on interface+port
- local handler, err = server.addserver(interface, port_number, listener, mode, ssl);
+ local handler, err = server.listen(interface, port_number, listener, {
+ read_size = mode,
+ tls_ctx = ssl,
+ tls_direct = service_info.encryption == "ssl";
+ sni_hosts = {},
+ });
if not handler then
log("error", "Failed to open server port %d on %s, %s", port_number, interface,
error_to_friendly_message(service_name, port_number, err));
@@ -137,6 +148,7 @@ local function activate(service_name)
active_services:add(service_name, interface, port_number, {
server = handler;
service = service_info;
+ tls_cfg = cfg;
});
end
end
@@ -163,7 +175,7 @@ end
local function register_service(service_name, service_info)
table.insert(services[service_name], service_info);
- if not active_services:get(service_name) then
+ if not active_services:get(service_name) and prosody.process_type == "prosody" then
log("debug", "No active service for %s, activating...", service_name);
local ok, err = activate(service_name);
if not ok then
@@ -222,15 +234,75 @@ end
-- Event handlers
+local function add_sni_host(host, service)
+ log("debug", "Gathering certificates for SNI for host %s, %s service", host, service or "default");
+ for name, interface, port, n, active_service --luacheck: ignore 213
+ in active_services:iter(service, nil, nil, nil) do
+ if active_service.server.hosts and active_service.tls_cfg then
+ local config_prefix = (active_service.config_prefix or name).."_";
+ if config_prefix == "_" then config_prefix = ""; end
+ local prefix_ssl_config = config.get(host, config_prefix.."ssl");
+ local alternate_host = name and config.get(host, name.."_host");
+ if not alternate_host and name == "https" then
+ -- TODO should this be some generic thing? e.g. in the service definition
+ alternate_host = config.get(host, "http_host");
+ end
+ local autocert = certmanager.find_host_cert(alternate_host or host);
+ -- luacheck: ignore 211/cfg
+ local ssl, err, cfg = certmanager.create_context(host, "server", prefix_ssl_config, autocert, active_service.tls_cfg);
+ if ssl then
+ active_service.server.hosts[alternate_host or host] = ssl;
+ else
+ log("error", "Error creating TLS context for SNI host %s: %s", host, err);
+ end
+ end
+ end
+end
prosody.events.add_handler("item-added/net-provider", function (event)
local item = event.item;
register_service(item.name, item);
+ for host in pairs(prosody.hosts) do
+ add_sni_host(host, item.name);
+ end
end);
prosody.events.add_handler("item-removed/net-provider", function (event)
local item = event.item;
unregister_service(item.name, item);
end);
+prosody.events.add_handler("host-activated", add_sni_host);
+prosody.events.add_handler("host-deactivated", function (host)
+ for name, interface, port, n, active_service --luacheck: ignore 213
+ in active_services:iter(nil, nil, nil, nil) do
+ if active_service.tls_cfg then
+ active_service.server.hosts[host] = nil;
+ end
+ end
+end);
+
+prosody.events.add_handler("config-reloaded", function ()
+ for service_name, interface, port, _, active_service in active_services:iter(nil, nil, nil, nil) do
+ if active_service.tls_cfg then
+ local service_info = active_service.service;
+ local config_prefix = (service_info.config_prefix or service_name).."_";
+ if config_prefix == "_" then
+ config_prefix = "";
+ end
+ local ssl, cfg, err = get_port_ssl_ctx(port, interface, config_prefix, service_info);
+ if ssl then
+ active_service.server:set_sslctx(ssl);
+ active_service.tls_cfg = cfg;
+ else
+ log("error", "Error reloading certificate for encrypted port for %s: %s", service_info.name,
+ error_to_friendly_message(service_name, port, err) or "unknown error");
+ end
+ end
+ end
+ for host in pairs(prosody.hosts) do
+ add_sni_host(host, nil);
+ end
+end, -1);
+
return {
activate = activate;
deactivate = deactivate;
diff --git a/core/rostermanager.lua b/core/rostermanager.lua
index 7bfad0a0..7b104339 100644
--- a/core/rostermanager.lua
+++ b/core/rostermanager.lua
@@ -285,15 +285,15 @@ end
function is_contact_pending_in(username, host, jid)
local roster = load_roster(username, host);
- return roster[false].pending[jid];
+ return roster[false].pending[jid] ~= nil;
end
-local function set_contact_pending_in(username, host, jid)
+local function set_contact_pending_in(username, host, jid, stanza)
local roster = load_roster(username, host);
local item = roster[jid];
if item and (item.subscription == "from" or item.subscription == "both") then
return; -- false
end
- roster[false].pending[jid] = true;
+ roster[false].pending[jid] = st.is_stanza(stanza) and st.preserialize(stanza) or true;
return save_roster(username, host, roster, jid);
end
function is_contact_pending_out(username, host, jid)
@@ -301,6 +301,11 @@ function is_contact_pending_out(username, host, jid)
local item = roster[jid];
return item and item.ask;
end
+local function is_contact_preapproved(username, host, jid)
+ local roster = load_roster(username, host);
+ local item = roster[jid];
+ return item and (item.approved == "true");
+end
local function set_contact_pending_out(username, host, jid) -- subscribe
local roster = load_roster(username, host);
local item = roster[jid];
@@ -331,9 +336,10 @@ local function unsubscribe(username, host, jid)
return save_roster(username, host, roster, jid);
end
local function subscribed(username, host, jid)
+ local roster = load_roster(username, host);
+ local item = roster[jid];
+
if is_contact_pending_in(username, host, jid) then
- local roster = load_roster(username, host);
- local item = roster[jid];
if not item then -- FIXME should roster item be auto-created?
item = {subscription = "none", groups = {}};
roster[jid] = item;
@@ -345,7 +351,17 @@ local function subscribed(username, host, jid)
end
roster[false].pending[jid] = nil;
return save_roster(username, host, roster, jid);
- end -- TODO else implement optional feature pre-approval (ask = subscribed)
+ elseif not item or item.subscription == "none" or item.subscription == "to" then
+ -- Contact is not subscribed and has not sent a subscription request.
+ -- We store a pre-approval as per RFC6121 3.4
+ if not item then
+ item = {subscription = "none", groups = {}};
+ roster[jid] = item;
+ end
+ item.approved = "true";
+ log("debug", "Storing preapproval for %s", jid);
+ return save_roster(username, host, roster, jid);
+ end
end
local function unsubscribed(username, host, jid)
local roster = load_roster(username, host);
@@ -403,6 +419,7 @@ return {
set_contact_pending_in = set_contact_pending_in;
is_contact_pending_out = is_contact_pending_out;
set_contact_pending_out = set_contact_pending_out;
+ is_contact_preapproved = is_contact_preapproved;
unsubscribe = unsubscribe;
subscribed = subscribed;
unsubscribed = unsubscribed;
diff --git a/core/s2smanager.lua b/core/s2smanager.lua
index 58269c49..49a5adae 100644
--- a/core/s2smanager.lua
+++ b/core/s2smanager.lua
@@ -9,10 +9,10 @@
local hosts = prosody.hosts;
-local tostring, pairs, setmetatable
- = tostring, pairs, setmetatable;
+local pairs, setmetatable = pairs, setmetatable;
local logger_init = require "util.logger".init;
+local sessionlib = require "util.session";
local log = logger_init("s2smanager");
@@ -26,30 +26,45 @@ local _ENV = nil;
-- luacheck: std none
local function new_incoming(conn)
- local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming", hosts = {} };
- session.log = logger_init("s2sin"..tostring(session):match("[a-f0-9]+$"));
- incoming_s2s[session] = true;
- return session;
+ local host_session = sessionlib.new("s2sin");
+ sessionlib.set_id(host_session);
+ sessionlib.set_logger(host_session);
+ sessionlib.set_conn(host_session, conn);
+ host_session.direction = "incoming";
+ host_session.incoming = true;
+ host_session.hosts = {};
+ incoming_s2s[host_session] = true;
+ return host_session;
end
local function new_outgoing(from_host, to_host)
- local host_session = { to_host = to_host, from_host = from_host, host = from_host,
- notopen = true, type = "s2sout_unauthed", direction = "outgoing" };
+ local host_session = sessionlib.new("s2sout");
+ sessionlib.set_id(host_session);
+ sessionlib.set_logger(host_session);
+ host_session.to_host = to_host;
+ host_session.from_host = from_host;
+ host_session.host = from_host;
+ host_session.notopen = true;
+ host_session.direction = "outgoing";
+ host_session.outgoing = true;
+ host_session.hosts = {};
hosts[from_host].s2sout[to_host] = host_session;
- local conn_name = "s2sout"..tostring(host_session):match("[a-f0-9]*$");
- host_session.log = logger_init(conn_name);
return host_session;
end
local resting_session = { -- Resting, not dead
destroyed = true;
type = "s2s_destroyed";
+ direction = "destroyed";
open_stream = function (session)
session.log("debug", "Attempt to open stream on resting session");
end;
close = function (session)
session.log("debug", "Attempt to close already-closed session");
end;
+ reset_stream = function (session)
+ session.log("debug", "Attempt to reset stream of already-closed session");
+ end;
filter = function (type, data) return data; end; --luacheck: ignore 212/type
}; resting_session.__index = resting_session;
@@ -63,23 +78,25 @@ local function retire_session(session, reason)
session.destruction_reason = reason;
- function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); end
- function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end
+ function session.send(data) log("debug", "Discarding data sent to resting session: %s", data); end
+ function session.data(data) log("debug", "Discarding data received from resting session: %s", data); end
session.thread = { run = function (_, data) return session.data(data) end };
session.sends2s = session.send;
return setmetatable(session, resting_session);
end
-local function destroy_session(session, reason)
+local function destroy_session(session, reason, bounce_reason)
if session.destroyed then return; end
- (session.log or log)("debug", "Destroying "..tostring(session.direction)
- .." session "..tostring(session.from_host).."->"..tostring(session.to_host)
- ..(reason and (": "..reason) or ""));
+ local log = session.log or log;
+ log("debug", "Destroying %s session %s->%s%s%s", session.direction, session.from_host, session.to_host, reason and ": " or "", reason or "");
if session.direction == "outgoing" then
hosts[session.from_host].s2sout[session.to_host] = nil;
- session:bounce_sendq(reason);
+ session:bounce_sendq(bounce_reason or reason);
elseif session.direction == "incoming" then
+ if session.outgoing then
+ hosts[session.to_host].s2sout[session.from_host] = nil;
+ end
incoming_s2s[session] = nil;
end
diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua
index 2843001a..7f296ff1 100644
--- a/core/sessionmanager.lua
+++ b/core/sessionmanager.lua
@@ -21,6 +21,7 @@ local config_get = require "core.configmanager".get;
local resourceprep = require "util.encodings".stringprep.resourceprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local generate_identifier = require "util.id".short;
+local sessionlib = require "util.session";
local initialize_filters = require "util.filters".initialize;
local gettime = require "socket".gettime;
@@ -29,23 +30,34 @@ local _ENV = nil;
-- luacheck: std none
local function new_session(conn)
- local session = { conn = conn, type = "c2s_unauthed", conntime = gettime() };
+ local session = sessionlib.new("c2s");
+ sessionlib.set_id(session);
+ sessionlib.set_logger(session);
+ sessionlib.set_conn(session, conn);
+
+ session.conntime = gettime();
local filter = initialize_filters(session);
local w = conn.write;
+
+ function session.rawsend(t)
+ t = filter("bytes/out", tostring(t));
+ if t then
+ local ret, err = w(conn, t);
+ if not ret then
+ session.log("debug", "Error writing to connection: %s", err);
+ return false, err;
+ end
+ end
+ return true;
+ end
+
session.send = function (t)
session.log("debug", "Sending[%s]: %s", session.type, t.top_tag and t:top_tag() or t:match("^[^>]*>?"));
if t.name then
t = filter("stanzas/out", t);
end
if t then
- t = filter("bytes/out", tostring(t));
- if t then
- local ret, err = w(conn, t);
- if not ret then
- session.log("debug", "Error writing to connection: %s", tostring(err));
- return false, err;
- end
- end
+ return session.rawsend(t);
end
return true;
end
@@ -73,8 +85,9 @@ local function retire_session(session)
end
end
- function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); return false; end
- function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end
+ function session.send(data) log("debug", "Discarding data sent to resting session: %s", data); return false; end
+ function session.rawsend(data) log("debug", "Discarding data sent to resting session: %s", data); return false; end
+ function session.data(data) log("debug", "Discarding data received from resting session: %s", data); end
session.thread = { run = function (_, data) return session.data(data) end };
return setmetatable(session, resting_session);
end
@@ -110,14 +123,15 @@ local function destroy_session(session, err)
retire_session(session);
end
-local function make_authenticated(session, username)
+local function make_authenticated(session, username, scope)
username = nodeprep(username);
if not username or #username == 0 then return nil, "Invalid username"; end
session.username = username;
if session.type == "c2s_unauthed" then
session.type = "c2s_unbound";
end
- session.log("info", "Authenticated as %s@%s", username or "(unknown)", session.host or "(unknown)");
+ session.auth_scope = scope;
+ session.log("info", "Authenticated as %s@%s", username, session.host or "(unknown)");
return true;
end
@@ -138,7 +152,7 @@ local function bind_resource(session, resource)
resource = event_payload.resource;
end
- resource = resourceprep(resource);
+ resource = resourceprep(resource or "", true);
resource = resource ~= "" and resource or generate_identifier();
--FIXME: Randomly-generated resources must be unique per-user, and never conflict with existing
diff --git a/core/stanza_router.lua b/core/stanza_router.lua
index f5a34f59..b54ea1ab 100644
--- a/core/stanza_router.lua
+++ b/core/stanza_router.lua
@@ -12,6 +12,7 @@ local hosts = _G.prosody.hosts;
local tostring = tostring;
local st = require "util.stanza";
local jid_split = require "util.jid".split;
+local jid_host = require "util.jid".host;
local jid_prepped_split = require "util.jid".prepped_split;
local full_sessions = _G.prosody.full_sessions;
@@ -27,7 +28,7 @@ local function handle_unhandled_stanza(host, origin, stanza) --luacheck: ignore
local st_type = stanza.attr.type;
if st_type == "error" or (name == "iq" and st_type == "result") then
if st_type == "error" then
- local err_type, err_condition, err_message = stanza:get_error();
+ local err_type, err_condition, err_message = stanza:get_error(); -- luacheck: ignore 211/err_message
log("debug", "Discarding unhandled error %s (%s, %s) from %s: %s",
name, err_type, err_condition or "unknown condition", origin_type, stanza:top_tag());
else
@@ -81,7 +82,7 @@ function core_process_stanza(origin, stanza)
local to_bare, from_bare;
if to then
if full_sessions[to] or bare_sessions[to] or hosts[to] then
- node, host = jid_split(to); -- TODO only the host is needed, optimize
+ host = jid_host(to);
else
node, host, resource = jid_prepped_split(to);
if not host then
@@ -111,8 +112,8 @@ function core_process_stanza(origin, stanza)
stanza.attr.from = from;
end
- if (origin.type == "s2sin" or origin.type == "c2s" or origin.type == "component") and xmlns == nil then
- if origin.type == "s2sin" and not origin.dummy then
+ if (origin.type == "s2sin" or origin.type == "s2sout" or origin.type == "c2s" or origin.type == "component") and xmlns == nil then
+ if (origin.type == "s2sin" or origin.type == "s2sout") and not origin.dummy then
local host_status = origin.hosts[from_host];
if not host_status or not host_status.authed then -- remote server trying to impersonate some other server?
log("warn", "Received a stanza claiming to be from %s, over a stream authed for %s!", from_host, origin.from_host);
@@ -171,8 +172,15 @@ function core_post_stanza(origin, stanza, preevents)
end
end
- local event_data = {origin=origin, stanza=stanza};
+ local event_data = {origin=origin, stanza=stanza, to_self=to_self};
+
if preevents then -- c2s connection
+ local result = hosts[origin.host].events.fire_event("pre-stanza", event_data);
+ if result ~= nil then
+ log("debug", "Stanza rejected by pre-stanza handler: %s", event_data.reason or "unknown reason");
+ return;
+ end
+
if hosts[origin.host].events.fire_event('pre-'..stanza.name..to_type, event_data) then return; end -- do preprocessing
end
local h = hosts[to_bare] or hosts[host or origin.host];
@@ -186,25 +194,25 @@ function core_post_stanza(origin, stanza, preevents)
end
function core_route_stanza(origin, stanza)
- local node, host, resource = jid_split(stanza.attr.to);
- local from_node, from_host, from_resource = jid_split(stanza.attr.from);
+ local to_host = jid_host(stanza.attr.to);
+ local from_host = jid_host(stanza.attr.from);
-- Auto-detect origin if not specified
origin = origin or hosts[from_host];
if not origin then return false; end
- if hosts[host] then
+ if hosts[to_host] then
-- old stanza routing code removed
core_post_stanza(origin, stanza);
else
local host_session = hosts[from_host];
if not host_session then
- log("error", "No hosts[from_host] (please report): %s", tostring(stanza));
+ log("error", "No hosts[from_host] (please report): %s", stanza);
else
local xmlns = stanza.attr.xmlns;
stanza.attr.xmlns = nil;
local routed = host_session.events.fire_event("route/remote", {
- origin = origin, stanza = stanza, from_host = from_host, to_host = host });
+ origin = origin, stanza = stanza, from_host = from_host, to_host = to_host });
stanza.attr.xmlns = xmlns; -- reset
if not routed then
log("debug", "Could not route stanza to remote");
diff --git a/core/statsmanager.lua b/core/statsmanager.lua
index 237b1dd5..686fc895 100644
--- a/core/statsmanager.lua
+++ b/core/statsmanager.lua
@@ -3,10 +3,12 @@ local config = require "core.configmanager";
local log = require "util.logger".init("stats");
local timer = require "util.timer";
local fire_event = prosody.events.fire_event;
+local array = require "util.array";
+local timed = require "util.openmetrics".timed;
local stats_interval_config = config.get("*", "statistics_interval");
local stats_interval = tonumber(stats_interval_config);
-if stats_interval_config and not stats_interval then
+if stats_interval_config and not stats_interval and stats_interval_config ~= "manual" then
log("error", "Invalid 'statistics_interval' setting, statistics will be disabled");
end
@@ -19,6 +21,9 @@ if not stats_provider and stats_interval then
elseif stats_provider and not stats_interval then
stats_interval = 60;
end
+if stats_interval_config == "manual" then
+ stats_interval = nil;
+end
local builtin_providers = {
internal = "util.statistics";
@@ -54,19 +59,152 @@ if stats == nil then
log("error", "Error loading statistics provider '%s': %s", stats_provider, stats_err);
end
-local measure, collect;
-local latest_stats = {};
-local changed_stats = {};
-local stats_extra = {};
+local measure, collect, metric, cork, uncork;
if stats then
- function measure(type, name)
- local f = assert(stats[type], "unknown stat type: "..type);
- return f(name);
+ function metric(type_, name, unit, description, labels, extra)
+ local registry = stats.metric_registry
+ local f = assert(registry[type_], "unknown metric family type: "..type_);
+ return f(registry, name, unit or "", description or "", labels, extra);
+ end
+
+ local function new_legacy_metric(stat_type, name, unit, description, fixed_label_key, fixed_label_value, extra)
+ local label_keys = array()
+ local conf = extra or {}
+ if fixed_label_key then
+ label_keys:push(fixed_label_key)
+ end
+ unit = unit or ""
+ local mf = metric(stat_type, "prosody_" .. name, unit, description, label_keys, conf);
+ if fixed_label_key then
+ mf = mf:with_partial_label(fixed_label_value)
+ end
+ return mf:with_labels()
+ end
+
+ local function unwrap_legacy_extra(extra, type_, name, unit)
+ local description = extra and extra.description or name.." "..type_
+ unit = extra and extra.unit or unit
+ return description, unit
end
- if stats_interval then
- log("debug", "Statistics enabled using %s provider, collecting every %d seconds", stats_provider_name, stats_interval);
+ -- These wrappers provide the pre-OpenMetrics interface of statsmanager
+ -- and moduleapi (module:measure).
+ local legacy_metric_wrappers = {
+ amount = function(name, fixed_label_key, fixed_label_value, extra)
+ local initial = 0
+ if type(extra) == "number" then
+ initial = extra
+ else
+ initial = extra and extra.initial or initial
+ end
+ local description, unit = unwrap_legacy_extra(extra, "amount", name)
+
+ local m = new_legacy_metric("gauge", name, unit, description, fixed_label_key, fixed_label_value)
+ m:set(initial or 0)
+ return function(v)
+ m:set(v)
+ end
+ end;
+
+ counter = function(name, fixed_label_key, fixed_label_value, extra)
+ if type(extra) == "number" then
+ -- previous versions of the API allowed passing an initial
+ -- value here; we do not allow that anymore, it is not a thing
+ -- which makes sense with counters
+ extra = nil
+ end
+
+ local description, unit = unwrap_legacy_extra(extra, "counter", name)
+
+ local m = new_legacy_metric("counter", name, unit, description, fixed_label_key, fixed_label_value)
+ m:set(0)
+ return function(v)
+ m:add(v)
+ end
+ end;
+
+ rate = function(name, fixed_label_key, fixed_label_value, extra)
+ if type(extra) == "number" then
+ -- previous versions of the API allowed passing an initial
+ -- value here; we do not allow that anymore, it is not a thing
+ -- which makes sense with counters
+ extra = nil
+ end
+
+ local description, unit = unwrap_legacy_extra(extra, "counter", name)
+
+ local m = new_legacy_metric("counter", name, unit, description, fixed_label_key, fixed_label_value)
+ m:set(0)
+ return function()
+ m:add(1)
+ end
+ end;
+
+ times = function(name, fixed_label_key, fixed_label_value, extra)
+ local conf = {}
+ if extra and extra.buckets then
+ conf.buckets = extra.buckets
+ else
+ conf.buckets = { 0.001, 0.01, 0.1, 1.0, 10.0, 100.0 }
+ end
+ local description, _ = unwrap_legacy_extra(extra, "times", name)
+
+ local m = new_legacy_metric("histogram", name, "seconds", description, fixed_label_key, fixed_label_value, conf)
+ return function()
+ return timed(m)
+ end
+ end;
+
+ sizes = function(name, fixed_label_key, fixed_label_value, extra)
+ local conf = {}
+ if extra and extra.buckets then
+ conf.buckets = extra.buckets
+ else
+ conf.buckets = { 1024, 4096, 32768, 131072, 1048576, 4194304, 33554432, 134217728, 1073741824 }
+ end
+ local description, _ = unwrap_legacy_extra(extra, "sizes", name)
+
+ local m = new_legacy_metric("histogram", name, "bytes", description, fixed_label_key, fixed_label_value, conf)
+ return function(v)
+ m:sample(v)
+ end
+ end;
+
+ distribution = function(name, fixed_label_key, fixed_label_value, extra)
+ if type(extra) == "string" then
+ -- compat with previous API
+ extra = { unit = extra }
+ end
+ local description, unit = unwrap_legacy_extra(extra, "distribution", name, "")
+ local m = new_legacy_metric("summary", name, unit, description, fixed_label_key, fixed_label_value)
+ return function(v)
+ m:sample(v)
+ end
+ end;
+ };
+
+ -- Argument order switched here to support the legacy statsmanager.measure
+ -- interface.
+ function measure(stat_type, name, extra, fixed_label_key, fixed_label_value)
+ local wrapper = assert(legacy_metric_wrappers[stat_type], "unknown legacy metric type "..stat_type)
+ return wrapper(name, fixed_label_key, fixed_label_value, extra)
+ end
+
+ if stats.cork then
+ function cork()
+ return stats:cork()
+ end
+
+ function uncork()
+ return stats:uncork()
+ end
+ else
+ function cork() end
+ function uncork() end
+ end
+
+ if stats_interval or stats_interval_config == "manual" then
local mark_collection_start = measure("times", "stats.collection");
local mark_processing_start = measure("times", "stats.processing");
@@ -74,44 +212,68 @@ if stats then
function collect()
local mark_collection_done = mark_collection_start();
fire_event("stats-update");
+ -- ensure that the backend is uncorked, in case it got stuck at
+ -- some point, to avoid infinite resource use
+ uncork()
mark_collection_done();
+ local manual_result = nil
- if stats.get_stats then
- changed_stats, stats_extra = {}, {};
- for stat_name, getter in pairs(stats.get_stats()) do
- local type, value, extra = getter();
- local old_value = latest_stats[stat_name];
- latest_stats[stat_name] = value;
- if value ~= old_value then
- changed_stats[stat_name] = value;
- end
- if extra then
- stats_extra[stat_name] = extra;
- end
- end
+ if stats.metric_registry then
+ -- only if supported by the backend, we fire the event which
+ -- provides the current metric values
local mark_processing_done = mark_processing_start();
- fire_event("stats-updated", { stats = latest_stats, changed_stats = changed_stats, stats_extra = stats_extra });
+ local metric_registry = stats.metric_registry;
+ fire_event("openmetrics-updated", { metric_registry = metric_registry })
mark_processing_done();
+ manual_result = metric_registry;
end
- return stats_interval;
+
+ return stats_interval, manual_result;
+ end
+ if stats_interval then
+ log("debug", "Statistics enabled using %s provider, collecting every %d seconds", stats_provider_name, stats_interval);
+ timer.add_task(stats_interval, collect);
+ prosody.events.add_handler("server-started", function () collect() end, -1);
+ prosody.events.add_handler("server-stopped", function () collect() end, -1);
+ else
+ log("debug", "Statistics enabled using %s provider, no scheduled collection", stats_provider_name);
end
- timer.add_task(stats_interval, collect);
- prosody.events.add_handler("server-started", function () collect() end, -1);
else
log("debug", "Statistics enabled using %s provider, collection is disabled", stats_provider_name);
end
else
log("debug", "Statistics disabled");
function measure() return measure; end
+
+ local dummy_mt = {}
+ function dummy_mt.__newindex()
+ end
+ function dummy_mt:__index()
+ return self
+ end
+ function dummy_mt:__call()
+ return self
+ end
+ local dummy = {}
+ setmetatable(dummy, dummy_mt)
+
+ function metric() return dummy; end
+ function cork() end
+ function uncork() end
end
+local exported_collect = nil;
+if stats_interval_config == "manual" then
+ exported_collect = collect;
+end
return {
+ collect = exported_collect;
measure = measure;
- get_stats = function ()
- return latest_stats, changed_stats, stats_extra;
- end;
- get = function (name)
- return latest_stats[name], stats_extra[name];
+ cork = cork;
+ uncork = uncork;
+ metric = metric;
+ get_metric_registry = function ()
+ return stats and stats.metric_registry or nil
end;
};
diff --git a/core/storagemanager.lua b/core/storagemanager.lua
index dea71733..856acad3 100644
--- a/core/storagemanager.lua
+++ b/core/storagemanager.lua
@@ -167,6 +167,39 @@ local map_shim_mt = {
return self.keyval_store:set(username, current);
end;
remove = {};
+ get_all = function (self, key)
+ if type(key) ~= "string" or key == "" then
+ return nil, "get_all only supports non-empty string keys";
+ end
+ local ret;
+ for username in self.keyval_store:users() do
+ local key_data = self:get(username, key);
+ if key_data then
+ if not ret then
+ ret = {};
+ end
+ ret[username] = key_data;
+ end
+ end
+ return ret;
+ end;
+ delete_all = function (self, key)
+ if type(key) ~= "string" or key == "" then
+ return nil, "delete_all only supports non-empty string keys";
+ end
+ local data = { [key] = self.remove };
+ local last_err;
+ for username in self.keyval_store:users() do
+ local ok, err = self:set_keys(username, data);
+ if not ok then
+ last_err = err;
+ end
+ end
+ if last_err then
+ return nil, last_err;
+ end
+ return true;
+ end;
};
}
diff --git a/core/usermanager.lua b/core/usermanager.lua
index bb5669cf..55faa0c9 100644
--- a/core/usermanager.lua
+++ b/core/usermanager.lua
@@ -9,12 +9,14 @@
local modulemanager = require "core.modulemanager";
local log = require "util.logger".init("usermanager");
local type = type;
-local ipairs = ipairs;
+local it = require "util.iterators";
local jid_bare = require "util.jid".bare;
+local jid_split = require "util.jid".split;
local jid_prep = require "util.jid".prep;
local config = require "core.configmanager";
local sasl_new = require "util.sasl".new;
local storagemanager = require "core.storagemanager";
+local set = require "util.set";
local prosody = _G.prosody;
local hosts = prosody.hosts;
@@ -34,10 +36,36 @@ local function new_null_provider()
});
end
+local global_admins_config = config.get("*", "admins");
+if type(global_admins_config) ~= "table" then
+ global_admins_config = nil; -- TODO: factor out moduleapi magic config handling and use it here
+end
+local global_admins = set.new(global_admins_config) / jid_prep;
+
+local admin_role = { ["prosody:admin"] = true };
+local global_authz_provider = {
+ get_user_roles = function (user) end; --luacheck: ignore 212/user
+ get_jid_roles = function (jid)
+ if global_admins:contains(jid) then
+ return admin_role;
+ end
+ end;
+ get_jids_with_role = function (role)
+ if role ~= "prosody:admin" then return {}; end
+ return it.to_array(global_admins);
+ end;
+};
+
local provider_mt = { __index = new_null_provider() };
local function initialize_host(host)
local host_session = hosts[host];
+
+ local authz_provider_name = config.get(host, "authorization") or "internal";
+
+ local authz_mod = modulemanager.load(host, "authz_"..authz_provider_name);
+ host_session.authz = authz_mod or global_authz_provider;
+
if host_session.type ~= "local" then return; end
host_session.events.add_handler("item-added/auth-provider", function (event)
@@ -66,6 +94,7 @@ local function initialize_host(host)
if auth_provider ~= "null" then
modulemanager.load(host, "auth_"..auth_provider);
end
+
end;
prosody.events.add_handler("host-activated", initialize_host, 100);
@@ -113,45 +142,64 @@ local function get_provider(host)
return hosts[host].users;
end
-local function is_admin(jid, host)
+local function get_roles(jid, host)
if host and not hosts[host] then return false; end
if type(jid) ~= "string" then return false; end
jid = jid_bare(jid);
host = host or "*";
- local host_admins = config.get(host, "admins");
- local global_admins = config.get("*", "admins");
-
- if host_admins and host_admins ~= global_admins then
- if type(host_admins) == "table" then
- for _,admin in ipairs(host_admins) do
- if jid_prep(admin) == jid then
- return true;
- end
- end
- elseif host_admins then
- log("error", "Option 'admins' for host '%s' is not a list", host);
- end
- end
+ local actor_user, actor_host = jid_split(jid);
+ local roles;
- if global_admins then
- if type(global_admins) == "table" then
- for _,admin in ipairs(global_admins) do
- if jid_prep(admin) == jid then
- return true;
- end
- end
- elseif global_admins then
- log("error", "Global option 'admins' is not a list");
- end
+ local authz_provider = (host ~= "*" and hosts[host].authz) or global_authz_provider;
+
+ if actor_user and actor_host == host then -- Local user
+ roles = authz_provider.get_user_roles(actor_user);
+ else -- Remote user/JID
+ roles = authz_provider.get_jid_roles(jid);
end
- -- Still not an admin, check with auth provider
- if host ~= "*" and hosts[host].users and hosts[host].users.is_admin then
- return hosts[host].users.is_admin(jid);
+ return roles;
+end
+
+local function set_roles(jid, host, roles)
+ if host and not hosts[host] then return false; end
+ if type(jid) ~= "string" then return false; end
+
+ jid = jid_bare(jid);
+ host = host or "*";
+
+ local actor_user, actor_host = jid_split(jid);
+
+ local authz_provider = (host ~= "*" and hosts[host].authz) or global_authz_provider;
+ if actor_user and actor_host == host then -- Local user
+ return authz_provider.set_user_roles(actor_user, roles)
+ else -- Remote entity
+ return authz_provider.set_jid_roles(jid, roles)
end
- return false;
+end
+
+local function is_admin(jid, host)
+ local roles = get_roles(jid, host);
+ return roles and roles["prosody:admin"];
+end
+
+local function get_users_with_role(role, host)
+ if not hosts[host] then return false; end
+ if type(role) ~= "string" then return false; end
+
+ return hosts[host].authz.get_users_with_role(role);
+end
+
+local function get_jids_with_role(role, host)
+ if host and not hosts[host] then return false; end
+ if type(role) ~= "string" then return false; end
+
+ host = host or "*";
+
+ local authz_provider = (host ~= "*" and hosts[host].authz) or global_authz_provider;
+ return authz_provider.get_jids_with_role(role);
end
return {
@@ -166,5 +214,9 @@ return {
users = users;
get_sasl_handler = get_sasl_handler;
get_provider = get_provider;
+ get_roles = get_roles;
+ set_roles = set_roles;
is_admin = is_admin;
+ get_users_with_role = get_users_with_role;
+ get_jids_with_role = get_jids_with_role;
};