aboutsummaryrefslogtreecommitdiffstats
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/certmanager.lua18
-rw-r--r--core/configmanager.lua48
-rw-r--r--core/loggingmanager.lua19
-rw-r--r--core/moduleapi.lua120
-rw-r--r--core/modulemanager.lua25
-rw-r--r--core/portmanager.lua58
-rw-r--r--core/rostermanager.lua28
-rw-r--r--core/s2smanager.lua46
-rw-r--r--core/sessionmanager.lua38
-rw-r--r--core/stanza_router.lua6
-rw-r--r--core/statsmanager.lua1
11 files changed, 346 insertions, 61 deletions
diff --git a/core/certmanager.lua b/core/certmanager.lua
index 5282a6f5..b20a0cdb 100644
--- a/core/certmanager.lua
+++ b/core/certmanager.lua
@@ -20,7 +20,6 @@ end
local configmanager = require "core.configmanager";
local log = require "util.logger".init("certmanager");
local ssl_context = ssl.context or softreq"ssl.context";
-local ssl_x509 = ssl.x509 or softreq"ssl.x509";
local ssl_newcontext = ssl.newcontext;
local new_config = require"util.sslconfig".new;
local stat = require "lfs".attributes;
@@ -106,7 +105,7 @@ local core_defaults = {
capath = "/etc/ssl/certs";
depth = 9;
protocol = "tlsv1+";
- verify = (ssl_x509 and { "peer", "client_once", }) or "none";
+ verify = "none";
options = {
cipher_server_preference = luasec_has.options.cipher_server_preference;
no_ticket = luasec_has.options.no_ticket;
@@ -123,8 +122,8 @@ local core_defaults = {
"P-521",
};
ciphers = { -- Enabled ciphers in order of preference:
- "HIGH+kEDH", -- Ephemeral Diffie-Hellman key exchange, if a 'dhparam' file is set
"HIGH+kEECDH", -- Ephemeral Elliptic curve Diffie-Hellman key exchange
+ "HIGH+kEDH", -- Ephemeral Diffie-Hellman key exchange, if a 'dhparam' file is set
"HIGH", -- Other "High strength" ciphers
-- Disabled cipher suites:
"!PSK", -- Pre-Shared Key - not used for XMPP
@@ -148,13 +147,6 @@ local path_options = { -- These we pass through resolve_path()
key = true, certificate = true, cafile = true, capath = true, dhparam = true
}
-if luasec_version < 5 and ssl_x509 then
- -- COMPAT mw/luasec-hg
- for i=1,#core_defaults.verifyext do -- Remove lsec_ prefix
- core_defaults.verify[#core_defaults.verify+1] = core_defaults.verifyext[i]:sub(6);
- end
-end
-
local function create_context(host, mode, ...)
local cfg = new_config();
cfg:apply(core_defaults);
@@ -177,8 +169,10 @@ local function create_context(host, mode, ...)
local user_ssl_config = cfg:final();
if mode == "server" then
- if not user_ssl_config.certificate then return nil, "No certificate present in SSL/TLS configuration for "..host; end
- if not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end
+ if not user_ssl_config.certificate then
+ log("info", "No certificate present in SSL/TLS configuration for %s. SNI will be required.", host);
+ end
+ if user_ssl_config.certificate and not user_ssl_config.key then return nil, "No key present in SSL/TLS configuration for "..host; end
end
for option in pairs(path_options) do
diff --git a/core/configmanager.lua b/core/configmanager.lua
index 1e67da9b..ae0a274a 100644
--- a/core/configmanager.lua
+++ b/core/configmanager.lua
@@ -7,15 +7,16 @@
--
local _G = _G;
-local setmetatable, rawget, rawset, io, os, error, dofile, type, pairs =
- setmetatable, rawget, rawset, io, os, error, dofile, type, pairs;
-local format, math_max = string.format, math.max;
+local setmetatable, rawget, rawset, io, os, error, dofile, type, pairs, ipairs =
+ setmetatable, rawget, rawset, io, os, error, dofile, type, pairs, ipairs;
+local format, math_max, t_insert = string.format, math.max, table.insert;
local envload = require"util.envload".envload;
local deps = require"util.dependencies";
local resolve_relative_path = require"util.paths".resolve_relative_path;
local glob_to_pattern = require"util.paths".glob_to_pattern;
local path_sep = package.config:sub(1,1);
+local get_traceback_table = require "util.debug".get_traceback_table;
local encodings = deps.softreq"util.encodings";
local nameprep = encodings and encodings.stringprep.nameprep or function (host) return host:lower(); end
@@ -100,8 +101,18 @@ end
-- Built-in Lua parser
do
local pcall = _G.pcall;
+ local function get_line_number(config_file)
+ local tb = get_traceback_table(nil, 2);
+ for i = 1, #tb do
+ if tb[i].info.short_src == config_file then
+ return tb[i].info.currentline;
+ end
+ end
+ end
parser = {};
function parser.load(data, config_file, config_table)
+ local set_options = {}; -- set_options[host.."/"..option_name] = true (when the option has been set already in this file)
+ local warnings = {};
local env;
-- The ' = true' are needed so as not to set off __newindex when we assign the functions below
env = setmetatable({
@@ -115,13 +126,26 @@ do
return rawget(_G, k);
end,
__newindex = function (_, k, v)
+ local host = env.__currenthost or "*";
+ local option_path = host.."/"..k;
+ if set_options[option_path] then
+ t_insert(warnings, ("%s:%d: Duplicate option '%s'"):format(config_file, get_line_number(config_file), k));
+ end
+ set_options[option_path] = true;
set(config_table, env.__currenthost or "*", k, v);
end
});
rawset(env, "__currenthost", "*") -- Default is global
function env.VirtualHost(name)
- name = nameprep(name);
+ if not name then
+ error("Host must have a name", 2);
+ end
+ local prepped_name = nameprep(name);
+ if not prepped_name then
+ error(format("Name of Host %q contains forbidden characters", name), 0);
+ end
+ name = prepped_name;
if rawget(config_table, name) and rawget(config_table[name], "component_module") then
error(format("Host %q clashes with previously defined %s Component %q, for services use a sub-domain like conference.%s",
name, config_table[name].component_module:gsub("^%a+$", { component = "external", muc = "MUC"}), name, name), 0);
@@ -139,7 +163,14 @@ do
env.Host, env.host = env.VirtualHost, env.VirtualHost;
function env.Component(name)
- name = nameprep(name);
+ if not name then
+ error("Component must have a name", 2);
+ end
+ local prepped_name = nameprep(name);
+ if not prepped_name then
+ error(format("Name of Component %q contains forbidden characters", name), 0);
+ end
+ name = prepped_name;
if rawget(config_table, name) and rawget(config_table[name], "defined")
and not rawget(config_table[name], "component_module") then
error(format("Component %q clashes with previously defined Host %q, for services use a sub-domain like conference.%s",
@@ -195,6 +226,11 @@ do
if f then
local ret, err = parser.load(f:read("*a"), file, config_table);
if not ret then error(err:gsub("%[string.-%]", file), 0); end
+ if err then
+ for _, warning in ipairs(err) do
+ t_insert(warnings, warning);
+ end
+ end
end
if not f then error("Error loading included "..file..": "..err, 0); end
return f, err;
@@ -217,7 +253,7 @@ do
return nil, err;
end
- return true;
+ return true, warnings;
end
end
diff --git a/core/loggingmanager.lua b/core/loggingmanager.lua
index cfa8246a..85a6380b 100644
--- a/core/loggingmanager.lua
+++ b/core/loggingmanager.lua
@@ -18,6 +18,9 @@ local getstyle, getstring = require "util.termcolours".getstyle, require "util.t
local config = require "core.configmanager";
local logger = require "util.logger";
+local have_pposix, pposix = pcall(require, "util.pposix");
+have_pposix = have_pposix and pposix._VERSION == "0.4.0";
+
local _ENV = nil;
-- luacheck: std none
@@ -232,6 +235,22 @@ local function log_to_console(sink_config)
end
log_sink_types.console = log_to_console;
+if have_pposix then
+ local syslog_opened;
+ local function log_to_syslog(sink_config) -- luacheck: ignore 212/sink_config
+ if not syslog_opened then
+ local facility = sink_config.syslog_facility or config.get("*", "syslog_facility");
+ pposix.syslog_open(sink_config.syslog_name or "prosody", facility);
+ syslog_opened = true;
+ end
+ local syslog = pposix.syslog_log;
+ return function (name, level, message, ...)
+ syslog(level, name, format(message, ...));
+ end;
+ end
+ log_sink_types.syslog = log_to_syslog;
+end
+
local function register_sink_type(name, sink_maker)
local old_sink_maker = log_sink_types[name];
log_sink_types[name] = sink_maker;
diff --git a/core/moduleapi.lua b/core/moduleapi.lua
index 10f9f04d..0a8adc36 100644
--- a/core/moduleapi.lua
+++ b/core/moduleapi.lua
@@ -14,13 +14,18 @@ local pluginloader = require "util.pluginloader";
local timer = require "util.timer";
local resolve_relative_path = require"util.paths".resolve_relative_path;
local st = require "util.stanza";
+local cache = require "util.cache";
+local errutil = require "util.error";
+local promise = require "util.promise";
+local time_now = require "util.time".now;
+local format = require "util.format".format;
local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat;
local error, setmetatable, type = error, setmetatable, type;
local ipairs, pairs, select = ipairs, pairs, select;
local tonumber, tostring = tonumber, tostring;
local require = require;
-local pack = table.pack or function(...) return {n=select("#",...), ...}; end -- table.pack is only in 5.2
+local pack = table.pack or require "util.table".pack; -- table.pack is only in 5.2
local unpack = table.unpack or unpack; --luacheck: ignore 113 -- renamed in 5.2
local prosody = prosody;
@@ -361,6 +366,91 @@ function api:send(stanza, origin)
return core_post_stanza(origin or hosts[self.host], stanza);
end
+function api:send_iq(stanza, origin, timeout)
+ local iq_cache = self._iq_cache;
+ if not iq_cache then
+ iq_cache = cache.new(256, function (_, iq)
+ iq.reject(errutil.new({
+ type = "wait", condition = "resource-constraint",
+ text = "evicted from iq tracking cache"
+ }));
+ end);
+ self._iq_cache = iq_cache;
+ end
+
+ local event_type;
+ if stanza.attr.from == self.host then
+ event_type = "host";
+ else -- assume bare since we can't hook full jids
+ event_type = "bare";
+ end
+ local result_event = "iq-result/"..event_type.."/"..stanza.attr.id;
+ local error_event = "iq-error/"..event_type.."/"..stanza.attr.id;
+ local cache_key = event_type.."/"..stanza.attr.id;
+
+ local p = promise.new(function (resolve, reject)
+ local function result_handler(event)
+ if event.stanza.attr.from == stanza.attr.to then
+ resolve(event);
+ return true;
+ end
+ end
+
+ local function error_handler(event)
+ if event.stanza.attr.from == stanza.attr.to then
+ reject(errutil.from_stanza(event.stanza), event);
+ return true;
+ end
+ end
+
+ if iq_cache:get(cache_key) then
+ reject(errutil.new({
+ type = "modify", condition = "conflict",
+ text = "IQ stanza id attribute already used",
+ }));
+ return;
+ end
+
+ self:hook(result_event, result_handler);
+ self:hook(error_event, error_handler);
+
+ local timeout_handle = self:add_timer(timeout or 120, function ()
+ reject(errutil.new({
+ type = "wait", condition = "remote-server-timeout",
+ text = "IQ stanza timed out",
+ }));
+ end);
+
+ local ok = iq_cache:set(cache_key, {
+ reject = reject, resolve = resolve,
+ timeout_handle = timeout_handle,
+ result_handler = result_handler, error_handler = error_handler;
+ });
+
+ if not ok then
+ reject(errutil.new({
+ type = "wait", condition = "internal-server-error",
+ text = "Could not store IQ tracking data"
+ }));
+ return;
+ end
+
+ self:send(stanza, origin);
+ end);
+
+ p:finally(function ()
+ local iq = iq_cache:get(cache_key);
+ if iq then
+ self:unhook(result_event, iq.result_handler);
+ self:unhook(error_event, iq.error_handler);
+ iq.timeout_handle:stop();
+ iq_cache:set(cache_key, nil);
+ end
+ end);
+
+ return p;
+end
+
function api:broadcast(jids, stanza, iter)
for jid in (iter or it.values)(jids) do
local new_stanza = st.clone(stanza);
@@ -432,4 +522,32 @@ function api:measure_global_event(event_name, stat_name)
return self:measure_object_event(prosody.events.wrappers, event_name, stat_name);
end
+local status_priorities = { error = 3, warn = 2, info = 1, core = 0 };
+
+function api:set_status(status_type, status_message, override)
+ local priority = status_priorities[status_type];
+ if not priority then
+ self:log("error", "set_status: Invalid status type '%s', assuming 'info'");
+ status_type, priority = "info", status_priorities.info;
+ end
+ local current_priority = status_priorities[self.status_type] or 0;
+ -- By default an 'error' status can only be overwritten by another 'error' status
+ if (current_priority >= status_priorities.error and priority < current_priority and override ~= true)
+ or (override == false and current_priority > priority) then
+ self:log("debug", "moduleapi: ignoring status [prio %d override %s]: %s", priority, override, status_message);
+ return;
+ end
+ self.status_type, self.status_message, self.status_time = status_type, status_message, time_now();
+ self:fire_event("module-status/updated", { name = self.name });
+end
+
+function api:log_status(level, msg, ...)
+ self:set_status(level, format(msg, ...));
+ return self:log(level, msg, ...);
+end
+
+function api:get_status()
+ return self.status_type, self.status_message, self.status_time;
+end
+
return api;
diff --git a/core/modulemanager.lua b/core/modulemanager.lua
index 17602459..5a45d6b6 100644
--- a/core/modulemanager.lua
+++ b/core/modulemanager.lua
@@ -23,8 +23,24 @@ local debug_traceback = debug.traceback;
local setmetatable, rawget = setmetatable, rawget;
local ipairs, pairs, type, t_insert = ipairs, pairs, type, table.insert;
-local autoload_modules = {prosody.platform, "presence", "message", "iq", "offline", "c2s", "s2s", "s2s_auth_certs"};
-local component_inheritable_modules = {"tls", "saslauth", "dialback", "iq", "s2s"};
+local autoload_modules = {
+ prosody.platform,
+ "presence",
+ "message",
+ "iq",
+ "offline",
+ "c2s",
+ "s2s",
+ "s2s_auth_certs",
+};
+local component_inheritable_modules = {
+ "tls",
+ "saslauth",
+ "dialback",
+ "iq",
+ "s2s",
+ "s2s_bidi",
+};
-- We need this to let modules access the real global namespace
local _G = _G;
@@ -169,6 +185,7 @@ local function do_load_module(host, module_name, state)
local mod, err = pluginloader.load_code(module_name, nil, pluginenv);
if not mod then
log("error", "Unable to load module '%s': %s", module_name or "nil", err or "nil");
+ api_instance:set_status("error", "Failed to load (see log)");
return nil, err;
end
@@ -182,6 +199,7 @@ local function do_load_module(host, module_name, state)
ok, err = call_module_method(pluginenv, "load");
if not ok then
log("warn", "Error loading module '%s' on '%s': %s", module_name, host, err or "nil");
+ api_instance:set_status("warn", "Error during load (see log)");
end
end
api_instance.reloading, api_instance.saved_state = nil, nil;
@@ -204,6 +222,9 @@ local function do_load_module(host, module_name, state)
if not ok then
modulemap[api_instance.host][module_name] = nil;
log("error", "Error initializing module '%s' on '%s': %s", module_name, host, err or "nil");
+ api_instance:set_status("warn", "Error during load (see log)");
+ else
+ api_instance:set_status("core", "Loaded", false);
end
return ok and pluginenv, err;
end
diff --git a/core/portmanager.lua b/core/portmanager.lua
index bed5eca5..55868c34 100644
--- a/core/portmanager.lua
+++ b/core/portmanager.lua
@@ -9,7 +9,8 @@ local set = require "util.set";
local table = table;
local setmetatable, rawset, rawget = setmetatable, rawset, rawget;
-local type, tonumber, tostring, ipairs = type, tonumber, tostring, ipairs;
+local type, tonumber, ipairs = type, tonumber, ipairs;
+local pairs = pairs;
local prosody = prosody;
local fire_event = prosody.events.fire_event;
@@ -95,25 +96,25 @@ local function activate(service_name)
}
bind_ports = set.new(type(bind_ports) ~= "table" and { bind_ports } or bind_ports );
- local mode, ssl = listener.default_mode or default_mode;
+ local mode = listener.default_mode or default_mode;
local hooked_ports = {};
for interface in bind_interfaces do
for port in bind_ports do
local port_number = tonumber(port);
if not port_number then
- log("error", "Invalid port number specified for service '%s': %s", service_info.name, tostring(port));
+ log("error", "Invalid port number specified for service '%s': %s", service_info.name, port);
elseif #active_services:search(nil, interface, port_number) > 0 then
log("error", "Multiple services configured to listen on the same port ([%s]:%d): %s, %s", interface, port,
active_services:search(nil, interface, port)[1][1].service.name or "<unnamed>", service_name or "<unnamed>");
else
- local err;
+ local ssl, cfg, err;
-- Create SSL context for this service/port
if service_info.encryption == "ssl" then
local global_ssl_config = config.get("*", "ssl") or {};
local prefix_ssl_config = config.get("*", config_prefix.."ssl") or global_ssl_config;
log("debug", "Creating context for direct TLS service %s on port %d", service_info.name, port);
- ssl, err = certmanager.create_context(service_info.name.." port "..port, "server",
+ ssl, err, cfg = certmanager.create_context(service_info.name.." port "..port, "server",
prefix_ssl_config[interface],
prefix_ssl_config[port],
prefix_ssl_config,
@@ -127,7 +128,12 @@ local function activate(service_name)
end
if not err then
-- Start listening on interface+port
- local handler, err = server.addserver(interface, port_number, listener, mode, ssl);
+ local handler, err = server.listen(interface, port_number, listener, {
+ read_size = mode,
+ tls_ctx = ssl,
+ tls_direct = service_info.encryption == "ssl";
+ sni_hosts = {},
+ });
if not handler then
log("error", "Failed to open server port %d on %s, %s", port_number, interface,
error_to_friendly_message(service_name, port_number, err));
@@ -137,6 +143,7 @@ local function activate(service_name)
active_services:add(service_name, interface, port_number, {
server = handler;
service = service_info;
+ tls_cfg = cfg;
});
end
end
@@ -222,15 +229,54 @@ end
-- Event handlers
+local function add_sni_host(host, service)
+ -- local global_ssl_config = config.get(host, "ssl") or {};
+ for name, interface, port, n, active_service --luacheck: ignore 213
+ in active_services:iter(service, nil, nil, nil) do
+ if active_service.server.hosts and active_service.tls_cfg then
+ -- local config_prefix = (active_service.config_prefix or name).."_";
+ -- if config_prefix == "_" then
+ -- config_prefix = "";
+ -- end
+ -- local prefix_ssl_config = config.get(host, config_prefix.."ssl") or global_ssl_config;
+ -- FIXME only global 'ssl' settings are mixed in here
+ -- TODO per host and per service settings should be merged in,
+ -- without overriding the per-host certificate
+ local ssl, err, cfg = certmanager.create_context(host, "server");
+ if ssl then
+ active_service.server.hosts[host] = ssl;
+ if not active_service.tls_cfg.certificate then
+ active_service.server.tls_ctx = ssl;
+ active_service.tls_cfg = cfg;
+ end
+ else
+ log("error", "err = %q", err);
+ end
+ end
+ end
+end
prosody.events.add_handler("item-added/net-provider", function (event)
local item = event.item;
register_service(item.name, item);
+ for host in pairs(prosody.hosts) do
+ add_sni_host(host, item.name);
+ end
end);
prosody.events.add_handler("item-removed/net-provider", function (event)
local item = event.item;
unregister_service(item.name, item);
end);
+prosody.events.add_handler("host-activated", add_sni_host);
+prosody.events.add_handler("host-deactivated", function (host)
+ for name, interface, port, n, active_service --luacheck: ignore 213
+ in active_services:iter(nil, nil, nil, nil) do
+ if active_service.tls_cfg then
+ active_service.server.hosts[host] = nil;
+ end
+ end
+end);
+
return {
activate = activate;
deactivate = deactivate;
diff --git a/core/rostermanager.lua b/core/rostermanager.lua
index 61b08002..d551a1b1 100644
--- a/core/rostermanager.lua
+++ b/core/rostermanager.lua
@@ -12,6 +12,7 @@
local log = require "util.logger".init("rostermanager");
local new_id = require "util.id".short;
+local new_cache = require "util.cache".new;
local pairs = pairs;
local tostring = tostring;
@@ -111,6 +112,23 @@ local function load_roster(username, host)
else -- Attempt to load roster for non-loaded user
log("debug", "load_roster: loading for offline user: %s", jid);
end
+ local roster_cache = hosts[host] and hosts[host].roster_cache;
+ if not roster_cache then
+ if hosts[host] then
+ roster_cache = new_cache(1024);
+ hosts[host].roster_cache = roster_cache;
+ end
+ else
+ roster = roster_cache:get(jid);
+ if roster then
+ log("debug", "load_roster: cache hit");
+ roster_cache:set(jid, roster);
+ if user then user.roster = roster; end
+ return roster;
+ else
+ log("debug", "load_roster: cache miss, loading from storage");
+ end
+ end
local roster_store = storagemanager.open(host, "roster", "keyval");
local data, err = roster_store:get(username);
roster = data or {};
@@ -134,6 +152,10 @@ local function load_roster(username, host)
if not err then
hosts[host].events.fire_event("roster-load", { username = username, host = host, roster = roster });
end
+ if roster_cache and not user then
+ log("debug", "load_roster: caching loaded roster");
+ roster_cache:set(jid, roster);
+ end
return roster, err;
end
@@ -263,15 +285,15 @@ end
function is_contact_pending_in(username, host, jid)
local roster = load_roster(username, host);
- return roster[false].pending[jid];
+ return roster[false].pending[jid] ~= nil;
end
-local function set_contact_pending_in(username, host, jid)
+local function set_contact_pending_in(username, host, jid, stanza)
local roster = load_roster(username, host);
local item = roster[jid];
if item and (item.subscription == "from" or item.subscription == "both") then
return; -- false
end
- roster[false].pending[jid] = true;
+ roster[false].pending[jid] = st.is_stanza(stanza) and st.preserialize(stanza) or true;
return save_roster(username, host, roster, jid);
end
function is_contact_pending_out(username, host, jid)
diff --git a/core/s2smanager.lua b/core/s2smanager.lua
index 58269c49..7471286c 100644
--- a/core/s2smanager.lua
+++ b/core/s2smanager.lua
@@ -9,10 +9,10 @@
local hosts = prosody.hosts;
-local tostring, pairs, setmetatable
- = tostring, pairs, setmetatable;
+local pairs, setmetatable = pairs, setmetatable;
local logger_init = require "util.logger".init;
+local sessionlib = require "util.session";
local log = logger_init("s2smanager");
@@ -26,18 +26,29 @@ local _ENV = nil;
-- luacheck: std none
local function new_incoming(conn)
- local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming", hosts = {} };
- session.log = logger_init("s2sin"..tostring(session):match("[a-f0-9]+$"));
- incoming_s2s[session] = true;
- return session;
+ local host_session = sessionlib.new("s2sin");
+ sessionlib.set_id(host_session);
+ sessionlib.set_logger(host_session);
+ sessionlib.set_conn(host_session, conn);
+ host_session.direction = "incoming";
+ host_session.incoming = true;
+ host_session.hosts = {};
+ incoming_s2s[host_session] = true;
+ return host_session;
end
local function new_outgoing(from_host, to_host)
- local host_session = { to_host = to_host, from_host = from_host, host = from_host,
- notopen = true, type = "s2sout_unauthed", direction = "outgoing" };
+ local host_session = sessionlib.new("s2sout");
+ sessionlib.set_id(host_session);
+ sessionlib.set_logger(host_session);
+ host_session.to_host = to_host;
+ host_session.from_host = from_host;
+ host_session.host = from_host;
+ host_session.notopen = true;
+ host_session.direction = "outgoing";
+ host_session.outgoing = true;
+ host_session.hosts = {};
hosts[from_host].s2sout[to_host] = host_session;
- local conn_name = "s2sout"..tostring(host_session):match("[a-f0-9]*$");
- host_session.log = logger_init(conn_name);
return host_session;
end
@@ -50,6 +61,9 @@ local resting_session = { -- Resting, not dead
close = function (session)
session.log("debug", "Attempt to close already-closed session");
end;
+ reset_stream = function (session)
+ session.log("debug", "Attempt to reset stream of already-closed session");
+ end;
filter = function (type, data) return data; end; --luacheck: ignore 212/type
}; resting_session.__index = resting_session;
@@ -63,8 +77,8 @@ local function retire_session(session, reason)
session.destruction_reason = reason;
- function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); end
- function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end
+ function session.send(data) log("debug", "Discarding data sent to resting session: %s", data); end
+ function session.data(data) log("debug", "Discarding data received from resting session: %s", data); end
session.thread = { run = function (_, data) return session.data(data) end };
session.sends2s = session.send;
return setmetatable(session, resting_session);
@@ -72,14 +86,16 @@ end
local function destroy_session(session, reason)
if session.destroyed then return; end
- (session.log or log)("debug", "Destroying "..tostring(session.direction)
- .." session "..tostring(session.from_host).."->"..tostring(session.to_host)
- ..(reason and (": "..reason) or ""));
+ local log = session.log or log;
+ log("debug", "Destroying %s session %s->%s%s%s", session.direction, session.from_host, session.to_host, reason and ": " or "", reason or "");
if session.direction == "outgoing" then
hosts[session.from_host].s2sout[session.to_host] = nil;
session:bounce_sendq(reason);
elseif session.direction == "incoming" then
+ if session.outgoing then
+ hosts[session.to_host].s2sout[session.from_host] = nil;
+ end
incoming_s2s[session] = nil;
end
diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua
index 2843001a..6c005fcd 100644
--- a/core/sessionmanager.lua
+++ b/core/sessionmanager.lua
@@ -21,6 +21,7 @@ local config_get = require "core.configmanager".get;
local resourceprep = require "util.encodings".stringprep.resourceprep;
local nodeprep = require "util.encodings".stringprep.nodeprep;
local generate_identifier = require "util.id".short;
+local sessionlib = require "util.session";
local initialize_filters = require "util.filters".initialize;
local gettime = require "socket".gettime;
@@ -29,23 +30,34 @@ local _ENV = nil;
-- luacheck: std none
local function new_session(conn)
- local session = { conn = conn, type = "c2s_unauthed", conntime = gettime() };
+ local session = sessionlib.new("c2s");
+ sessionlib.set_id(session);
+ sessionlib.set_logger(session);
+ sessionlib.set_conn(session, conn);
+
+ session.conntime = gettime();
local filter = initialize_filters(session);
local w = conn.write;
+
+ function session.rawsend(t)
+ t = filter("bytes/out", tostring(t));
+ if t then
+ local ret, err = w(conn, t);
+ if not ret then
+ session.log("debug", "Error writing to connection: %s", err);
+ return false, err;
+ end
+ end
+ return true;
+ end
+
session.send = function (t)
session.log("debug", "Sending[%s]: %s", session.type, t.top_tag and t:top_tag() or t:match("^[^>]*>?"));
if t.name then
t = filter("stanzas/out", t);
end
if t then
- t = filter("bytes/out", tostring(t));
- if t then
- local ret, err = w(conn, t);
- if not ret then
- session.log("debug", "Error writing to connection: %s", tostring(err));
- return false, err;
- end
- end
+ return session.rawsend(t);
end
return true;
end
@@ -73,8 +85,8 @@ local function retire_session(session)
end
end
- function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); return false; end
- function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end
+ function session.send(data) log("debug", "Discarding data sent to resting session: %s", data); return false; end
+ function session.data(data) log("debug", "Discarding data received from resting session: %s", data); end
session.thread = { run = function (_, data) return session.data(data) end };
return setmetatable(session, resting_session);
end
@@ -117,7 +129,7 @@ local function make_authenticated(session, username)
if session.type == "c2s_unauthed" then
session.type = "c2s_unbound";
end
- session.log("info", "Authenticated as %s@%s", username or "(unknown)", session.host or "(unknown)");
+ session.log("info", "Authenticated as %s@%s", username, session.host or "(unknown)");
return true;
end
@@ -138,7 +150,7 @@ local function bind_resource(session, resource)
resource = event_payload.resource;
end
- resource = resourceprep(resource);
+ resource = resourceprep(resource or "", true);
resource = resource ~= "" and resource or generate_identifier();
--FIXME: Randomly-generated resources must be unique per-user, and never conflict with existing
diff --git a/core/stanza_router.lua b/core/stanza_router.lua
index f5a34f59..a74f3b6f 100644
--- a/core/stanza_router.lua
+++ b/core/stanza_router.lua
@@ -111,8 +111,8 @@ function core_process_stanza(origin, stanza)
stanza.attr.from = from;
end
- if (origin.type == "s2sin" or origin.type == "c2s" or origin.type == "component") and xmlns == nil then
- if origin.type == "s2sin" and not origin.dummy then
+ if (origin.type == "s2sin" or origin.type == "s2sout" or origin.type == "c2s" or origin.type == "component") and xmlns == nil then
+ if (origin.type == "s2sin" or origin.type == "s2sout") and not origin.dummy then
local host_status = origin.hosts[from_host];
if not host_status or not host_status.authed then -- remote server trying to impersonate some other server?
log("warn", "Received a stanza claiming to be from %s, over a stream authed for %s!", from_host, origin.from_host);
@@ -199,7 +199,7 @@ function core_route_stanza(origin, stanza)
else
local host_session = hosts[from_host];
if not host_session then
- log("error", "No hosts[from_host] (please report): %s", tostring(stanza));
+ log("error", "No hosts[from_host] (please report): %s", stanza);
else
local xmlns = stanza.attr.xmlns;
stanza.attr.xmlns = nil;
diff --git a/core/statsmanager.lua b/core/statsmanager.lua
index 237b1dd5..50798ad0 100644
--- a/core/statsmanager.lua
+++ b/core/statsmanager.lua
@@ -97,6 +97,7 @@ if stats then
end
timer.add_task(stats_interval, collect);
prosody.events.add_handler("server-started", function () collect() end, -1);
+ prosody.events.add_handler("server-stopped", function () collect() end, -1);
else
log("debug", "Statistics enabled using %s provider, collection is disabled", stats_provider_name);
end