diff options
83 files changed, 2384 insertions, 1525 deletions
@@ -1,4 +1,10 @@ -== Core development team == -Matthew Wild (matthew.wild AT heavy-horse.co.uk) -Waqas Hussain (waqas20 AT gmail.com) +The Prosody project is open to contributions (see HACKERS file), but is +maintained daily by: + + - Matthew Wild (mail: matthew [at] prosody.im) + - Waqas Hussain (mail: waqas [at] prosody.im) + - Kim Alvefur (mail: zash [at] prosody.im) + +You can reach us collectively by email: developers [at] prosody.im +or in realtime in the Prosody chatroom: prosody@conference.prosody.im @@ -1,9 +1,11 @@ -The easiest way to install dependencies is using the luarocks tool. -Rocks: -luaexpat -luasocket -luafilesystem +For full information on our dependencies, version requirements, and +where to find them, see http://prosody.im/doc/depends + +If you have luarocks available on your platform, install the following: + + - luaexpat + - luasocket + - luafilesystem + - luasec -Non-rocks: -LuaSec for SSL connections @@ -1,18 +1,18 @@ -(This file was created from -http://prosody.im/doc/installing_from_source on 2012-05-12) +(This file was created from +http://prosody.im/doc/installing_from_source on 2013-03-31) -===== Building ===== +====== Installing from source ====== ==== Dependencies ==== There are a couple of libraries which Prosody needs installed before you can build it. These are: - * lua5.1: The interpreter + * lua5.1: The Lua 5.1 interpreter * liblua5.1: Lua 5.1 library * libssl 0.9.8: OpenSSL * libidn11: GNU libidn library, version 1.1 -These can be installed on Debian/Ubuntu with the packages: -lua5.1 liblua5.1-dev libidn11-dev libssl-dev +These can be installed on Debian/Ubuntu with the packages: lua5.1 +liblua5.1-dev libidn11-dev libssl-dev On Mandriva try: urpmi lua liblua-devel libidn-devel libopenssl-devel @@ -33,7 +33,8 @@ accepts. You can load a preset using: ./configure --ostype=PRESET -Where PRESET can currently be one of: debian, macosx or freebsd +Where PRESET can currently be one of: 'debian', 'macosx' or (in 0.8 +and later) 'freebsd' ==== make ==== Once you have run configure successfully, then you can simply run: @@ -38,7 +38,7 @@ install: prosody.install prosodyctl.install prosody.cfg.lua.install util/encodin install -m644 certs/* $(CONFIG)/certs install -m644 man/prosodyctl.man $(MAN)/man1/prosodyctl.1 test -e $(CONFIG)/prosody.cfg.lua || install -m644 prosody.cfg.lua.install $(CONFIG)/prosody.cfg.lua - test -e prosody.version && install prosody.version $(SOURCE)/prosody.version || true + test -e prosody.version && install -m644 prosody.version $(SOURCE)/prosody.version || true $(MAKE) install -C util-src clean: @@ -1,10 +1,5 @@ -== 0.9 == -- IPv6 -- SASL EXTERNAL -- Roster providers -- Web interface - == 1.0 == +- Roster providers - Statistics - Clustering - World domination diff --git a/certs/openssl.cnf b/certs/openssl.cnf index db1640b9..091409c4 100644 --- a/certs/openssl.cnf +++ b/certs/openssl.cnf @@ -2,7 +2,7 @@ oid_section = new_oids [ new_oids ] -# RFC 3920 section 5.1.1 defines this OID +# RFC 6120 section 13.7.1.4. defines this OID xmppAddr = 1.3.6.1.5.5.7.8.5 # RFC 4985 defines this OID @@ -40,7 +40,7 @@ subjectAltName = @subject_alternative_name [ subject_alternative_name ] -# See http://tools.ietf.org/html/draft-ietf-xmpp-3920bis#section-13.7.1.2 for more info. +# See http://tools.ietf.org/html/rfc6120#section-13.7.1.2 for more info. DNS.0 = example.com otherName.0 = xmppAddr;FORMAT:UTF8,UTF8:example.com @@ -41,15 +41,17 @@ Configure Prosody prior to building. Default is "$LUA_SUFFIX" (lua$LUA_SUFFIX...) --with-lua=PREFIX Use Lua from given prefix. Default is $LUA_DIR +--runwith=BINARY What Lua binary to set as runtime environment. + Default is $RUNWITH --with-lua-include=DIR You can also specify Lua's includes dir. Default is \$LUA_DIR/include --with-lua-lib=DIR You can also specify Lua's libraries dir. Default is \$LUA_DIR/lib --with-idn=LIB The name of the IDN library to link with. Default is $IDN_LIB ---idn-library=(idn|icu) Select library to use for IDNA functionality. - idn: use GNU libidn (default) - icu: use ICU from IBM +--idn-library=(idn|icu) Select library to use for IDNA functionality. + idn: use GNU libidn (default) + icu: use ICU from IBM --with-ssl=LIB The name of the SSL to link with. Default is $OPENSSL_LIB --cflags=FLAGS Flags to pass to the compiler @@ -95,6 +97,7 @@ do if [ "$OSTYPE" = "debian" ] then LUA_SUFFIX="5.1"; LUA_SUFFIX_SET=yes + RUNWITH="lua5.1" LUA_INCDIR=/usr/include/lua5.1; LUA_INCDIR_SET=yes CFLAGS="$CFLAGS -D_GNU_SOURCE" diff --git a/core/certmanager.lua b/core/certmanager.lua index 8607e618..b91f7110 100644 --- a/core/certmanager.lua +++ b/core/certmanager.lua @@ -27,7 +27,7 @@ end module "certmanager" -- Global SSL options if not overridden per-host -local default_ssl_config = configmanager.get("*", "core", "ssl"); +local default_ssl_config = configmanager.get("*", "ssl"); local default_capath = "/etc/ssl/certs"; local default_verify = (ssl and ssl.x509 and { "peer", "client_once", }) or "none"; local default_options = { "no_sslv2", luasec_has_noticket and "no_ticket" or nil }; @@ -100,7 +100,7 @@ function create_context(host, mode, user_ssl_config) end function reload_ssl_config() - default_ssl_config = configmanager.get("*", "core", "ssl"); + default_ssl_config = configmanager.get("*", "ssl"); end prosody.events.add_handler("config-reloaded", reload_ssl_config); diff --git a/core/configmanager.lua b/core/configmanager.lua index 51b9f5fe..9720f48a 100644 --- a/core/configmanager.lua +++ b/core/configmanager.lua @@ -7,8 +7,8 @@ -- local _G = _G; -local setmetatable, loadfile, pcall, rawget, rawset, io, error, dofile, type, pairs, table = - setmetatable, loadfile, pcall, rawget, rawset, io, error, dofile, type, pairs, table; +local setmetatable, rawget, rawset, io, error, dofile, type, pairs, table = + setmetatable, rawget, rawset, io, error, dofile, type, pairs, table; local format, math_max = string.format, math.max; local fire_event = prosody and prosody.events.fire_event or function () end; @@ -22,67 +22,52 @@ module "configmanager" local parsers = {}; local config_mt = { __index = function (t, k) return rawget(t, "*"); end}; -local config = setmetatable({ ["*"] = { core = {} } }, config_mt); +local config = setmetatable({ ["*"] = { } }, config_mt); -- When host not found, use global -local host_mt = { }; - --- When key not found in section, check key in global's section -function section_mt(section_name) - return { __index = function (t, k) - local section = rawget(config["*"], section_name); - if not section then return nil; end - return section[k]; - end - }; -end +local host_mt = { __index = function(_, k) return config["*"][k] end } function getconfig() return config; end -function get(host, section, key) - if not key then - section, key = "core", section; - end - local sec = config[host][section]; - if sec then - return sec[key]; +function get(host, key, _oldkey) + if key == "core" then + key = _oldkey; -- COMPAT with code that still uses "core" end - return nil; + return config[host][key]; end -function _M.rawget(host, section, key) +function _M.rawget(host, key, _oldkey) + if key == "core" then + key = _oldkey; -- COMPAT with code that still uses "core" + end local hostconfig = rawget(config, host); if hostconfig then - local sectionconfig = rawget(hostconfig, section); - if sectionconfig then - return rawget(sectionconfig, key); - end + return rawget(hostconfig, key); end end -local function set(config, host, section, key, value) - if host and section and key then +local function set(config, host, key, value) + if host and key then local hostconfig = rawget(config, host); if not hostconfig then hostconfig = rawset(config, host, setmetatable({}, host_mt))[host]; end - if not rawget(hostconfig, section) then - hostconfig[section] = setmetatable({}, section_mt(section)); - end - hostconfig[section][key] = value; + hostconfig[key] = value; return true; end return false; end -function _M.set(host, section, key, value) - return set(config, host, section, key, value); +function _M.set(host, key, value, _oldvalue) + if key == "core" then + key, value = value, _oldvalue; --COMPAT with code that still uses "core" + end + return set(config, host, key, value); end -- Helper function to resolve relative paths (needed by config) do - local rel_path_start = ".."..path_sep; function resolve_relative_path(parent_path, path) if path then -- Some normalization @@ -122,7 +107,7 @@ function load(filename, format) if parsers[format] and parsers[format].load then local f, err = io.open(filename); if f then - local new_config = setmetatable({ ["*"] = { core = {} } }, config_mt); + local new_config = setmetatable({ ["*"] = { } }, config_mt); local ok, err = parsers[format].load(f:read("*a"), filename, new_config); f:close(); if ok then @@ -166,7 +151,7 @@ end -- Built-in Lua parser do local pcall, setmetatable = _G.pcall, _G.setmetatable; - local rawget, tostring = _G.rawget, _G.tostring; + local rawget = _G.rawget; parsers.lua = {}; function parsers.lua.load(data, config_file, config) local env; @@ -176,53 +161,50 @@ do Component = true, component = true, Include = true, include = true, RunScript = true }, { __index = function (t, k) - return rawget(_G, k) or - function (settings_table) - config[__currenthost or "*"][k] = settings_table; - end; + return rawget(_G, k); end, __newindex = function (t, k, v) - set(config, env.__currenthost or "*", "core", k, v); + set(config, env.__currenthost or "*", k, v); end }); rawset(env, "__currenthost", "*") -- Default is global function env.VirtualHost(name) - if rawget(config, name) and rawget(config[name].core, "component_module") then + if rawget(config, name) and rawget(config[name], "component_module") then error(format("Host %q clashes with previously defined %s Component %q, for services use a sub-domain like conference.%s", - name, config[name].core.component_module:gsub("^%a+$", { component = "external", muc = "MUC"}), name, name), 0); + name, config[name].component_module:gsub("^%a+$", { component = "external", muc = "MUC"}), name, name), 0); end rawset(env, "__currenthost", name); -- Needs at least one setting to logically exist :) - set(config, name or "*", "core", "defined", true); + set(config, name or "*", "defined", true); return function (config_options) rawset(env, "__currenthost", "*"); -- Return to global scope for option_name, option_value in pairs(config_options) do - set(config, name or "*", "core", option_name, option_value); + set(config, name or "*", option_name, option_value); end end; end env.Host, env.host = env.VirtualHost, env.VirtualHost; function env.Component(name) - if rawget(config, name) and rawget(config[name].core, "defined") and not rawget(config[name].core, "component_module") then + if rawget(config, name) and rawget(config[name], "defined") and not rawget(config[name], "component_module") then error(format("Component %q clashes with previously defined Host %q, for services use a sub-domain like conference.%s", name, name, name), 0); end - set(config, name, "core", "component_module", "component"); + set(config, name, "component_module", "component"); -- Don't load the global modules by default - set(config, name, "core", "load_global_modules", false); + set(config, name, "load_global_modules", false); rawset(env, "__currenthost", name); local function handle_config_options(config_options) rawset(env, "__currenthost", "*"); -- Return to global scope for option_name, option_value in pairs(config_options) do - set(config, name or "*", "core", option_name, option_value); + set(config, name or "*", option_name, option_value); end end return function (module) if type(module) == "string" then - set(config, name, "core", "component_module", module); + set(config, name, "component_module", module); return handle_config_options; end return handle_config_options(module); @@ -230,7 +212,7 @@ do end env.component = env.Component; - function env.Include(file, wildcard) + function env.Include(file) if file:match("[*?]") then local path_pos, glob = file:match("()([^"..path_sep.."]+)$"); local path = file:sub(1, math_max(path_pos-2,0)); diff --git a/core/hostmanager.lua b/core/hostmanager.lua index cee4a1d6..06ba72a1 100644 --- a/core/hostmanager.lua +++ b/core/hostmanager.lua @@ -17,14 +17,15 @@ local uuid_gen = require "util.uuid".generate; local log = require "util.logger".init("hostmanager"); -local hosts = hosts; +local hosts = prosody.hosts; local prosody_events = prosody.events; if not _G.prosody.incoming_s2s then require "core.s2smanager"; end local incoming_s2s = _G.prosody.incoming_s2s; +local core_route_stanza = _G.prosody.core_route_stanza; -local pairs, select = pairs, select; +local pairs, select, rawget = pairs, select, rawget; local tostring, type = tostring, type; module "hostmanager" @@ -36,8 +37,8 @@ local function load_enabled_hosts(config) local activated_any_host; for host, host_config in pairs(defined_hosts) do - if host ~= "*" and host_config.core.enabled ~= false then - if not host_config.core.component_module then + if host ~= "*" and host_config.enabled ~= false then + if not host_config.component_module then activated_any_host = true; end activate(host, host_config); @@ -66,18 +67,18 @@ local function host_send(stanza) end function activate(host, host_config) - if hosts[host] then return nil, "The host "..host.." is already activated"; end + if rawget(hosts, host) then return nil, "The host "..host.." is already activated"; end host_config = host_config or configmanager.getconfig()[host]; if not host_config then return nil, "Couldn't find the host "..tostring(host).." defined in the current config"; end local host_session = { host = host; s2sout = {}; events = events_new(); - dialback_secret = configmanager.get(host, "core", "dialback_secret") or uuid_gen(); + dialback_secret = configmanager.get(host, "dialback_secret") or uuid_gen(); send = host_send; modules = {}; }; - if not host_config.core.component_module then -- host + if not host_config.component_module then -- host host_session.type = "local"; host_session.sessions = {}; else -- component @@ -85,9 +86,9 @@ function activate(host, host_config) end hosts[host] = host_session; if not host:match("[@/]") then - disco_items:set(host:match("%.(.*)") or "*", host, host_config.core.name or true); + disco_items:set(host:match("%.(.*)") or "*", host, host_config.name or true); end - for option_name in pairs(host_config.core) do + for option_name in pairs(host_config) do if option_name:match("_ports$") or option_name:match("_interface$") then log("warn", "%s: Option '%s' has no effect for virtual hosts - put it in the server-wide section instead", host, option_name); end diff --git a/core/loggingmanager.lua b/core/loggingmanager.lua index c3fc83e4..c69dede8 100644 --- a/core/loggingmanager.lua +++ b/core/loggingmanager.lua @@ -146,7 +146,7 @@ function reload_logging() logger.reset(); - local debug_mode = config.get("*", "core", "debug"); + local debug_mode = config.get("*", "debug"); default_logging = { { to = "console" , levels = { min = (debug_mode and "debug") or "info" } } }; default_file_logging = { @@ -154,7 +154,7 @@ function reload_logging() }; default_timestamp = "%b %d %H:%M:%S"; - logging_config = config.get("*", "core", "log") or default_logging; + logging_config = config.get("*", "log") or default_logging; for name, sink_maker in pairs(old_sink_types) do diff --git a/core/moduleapi.lua b/core/moduleapi.lua index 20898fcf..ed75669b 100644 --- a/core/moduleapi.lua +++ b/core/moduleapi.lua @@ -21,7 +21,10 @@ local tonumber, tostring = tonumber, tostring; local prosody = prosody; local hosts = prosody.hosts; -local core_post_stanza = prosody.core_post_stanza; + +-- FIXME: This assert() is to try and catch an obscure bug (2013-04-05) +local core_post_stanza = assert(prosody.core_post_stanza, + "prosody.core_post_stanza is nil, please report this as a bug"); -- Registry of shared module data local shared_data = setmetatable({}, { __mode = "v" }); @@ -62,6 +65,20 @@ end function api:add_extension(data) self:add_item("extension", data); end +function api:has_feature(xmlns) + for _, feature in ipairs(self:get_host_items("feature")) do + if feature == xmlns then return true; end + end + return false; +end +function api:has_identity(category, type, name) + for _, id in ipairs(self:get_host_items("identity")) do + if id.category == category and id.type == type and id.name == name then + return true; + end + end + return false; +end function api:fire_event(...) return (hosts[self.host] or prosody).events.fire_event(...); @@ -167,12 +184,9 @@ function api:shared(...) end function api:get_option(name, default_value) - local value = config.get(self.host, self.name, name); + local value = config.get(self.host, name); if value == nil then - value = config.get(self.host, "core", name); - if value == nil then - value = default_value; - end + value = default_value; end return value; end @@ -256,6 +270,22 @@ function api:get_option_set(name, ...) return set.new(value); end +function api:get_option_inherited_set(name, ...) + local value = self:get_option_set(name, ...); + local global_value = self:context("*"):get_option_set(name, ...); + if not value then + return global_value; + elseif not global_value then + return value; + end + value:include(global_value); + return value; +end + +function api:context(host) + return setmetatable({host=host or "*"}, {__index=self,__newindex=self}); +end + function api:add_item(key, value) self.items = self.items or {}; self.items[key] = self.items[key] or {}; @@ -274,23 +304,7 @@ function api:remove_item(key, value) end function api:get_host_items(key) - local result = {}; - for mod_name, module in pairs(modulemanager.get_modules(self.host)) do - module = module.module; - if module.items then - for _, item in ipairs(module.items[key] or NULL) do - t_insert(result, item); - end - end - end - for mod_name, module in pairs(modulemanager.get_modules("*")) do - module = module.module; - if module.items then - for _, item in ipairs(module.items[key] or NULL) do - t_insert(result, item); - end - end - end + local result = modulemanager.get_items(key, self.host) or {}; return result; end @@ -305,7 +319,13 @@ function api:handle_items(type, added_cb, removed_cb, existing) end function api:provides(name, item) - if not item then item = self.environment; end + -- if not item then item = setmetatable({}, { __index = function(t,k) return rawget(self.environment, k); end }); end + if not item then + item = {} + for k,v in pairs(self.environment) do + if k ~= "module" then item[k] = v; end + end + end if not item.name then local item_name = self.name; -- Strip a provider prefix to find the item name @@ -315,6 +335,7 @@ function api:provides(name, item) end item.name = item_name; end + item._provided_by = self.name; self:add_item(name.."-provider", item); end @@ -339,4 +360,8 @@ function api:load_resource(path, mode) return io.open(path, mode); end +function api:open_store(name, type) + return storagemanager.open(self.host, name or self.name, type); +end + return api; diff --git a/core/modulemanager.lua b/core/modulemanager.lua index 4ba2c27e..535c227b 100644 --- a/core/modulemanager.lua +++ b/core/modulemanager.lua @@ -19,7 +19,7 @@ local prosody = prosody; local pcall, xpcall = pcall, xpcall; local setmetatable, rawget = setmetatable, rawget; -local pairs, type, tostring = pairs, type, tostring; +local ipairs, pairs, type, tostring, t_insert = ipairs, pairs, type, tostring, table.insert; local debug_traceback = debug.traceback; local unpack, select = unpack, select; @@ -44,12 +44,12 @@ local modulemap = { ["*"] = {} }; -- Load modules when a host is activated function load_modules_for_host(host) - local component = config.get(host, "core", "component_module"); + local component = config.get(host, "component_module"); - local global_modules_enabled = config.get("*", "core", "modules_enabled"); - local global_modules_disabled = config.get("*", "core", "modules_disabled"); - local host_modules_enabled = config.get(host, "core", "modules_enabled"); - local host_modules_disabled = config.get(host, "core", "modules_disabled"); + local global_modules_enabled = config.get("*", "modules_enabled"); + local global_modules_disabled = config.get("*", "modules_disabled"); + local host_modules_enabled = config.get(host, "modules_enabled"); + local host_modules_disabled = config.get(host, "modules_disabled"); if host_modules_enabled == global_modules_enabled then host_modules_enabled = nil; end if host_modules_disabled == global_modules_disabled then host_modules_disabled = nil; end @@ -218,7 +218,7 @@ local function do_reload_module(host, name) saved = ret; else log("warn", "Error saving module '%s:%s' state: %s", host, name, ret); - if not config.get(host, "core", "force_module_reload") then + if not config.get(host, "force_module_reload") then log("warn", "Aborting reload due to error, set force_module_reload to ignore this"); return nil, "save-state-failed"; else @@ -278,6 +278,23 @@ function get_module(host, name) return modulemap[host] and modulemap[host][name]; end +function get_items(key, host) + local result = {}; + local modules = modulemap[host]; + if not key or not host or not modules then return nil; end + + for _, module in pairs(modules) do + local mod = module.module; + if mod.items and mod.items[key] then + for _, value in ipairs(mod.items[key]) do + t_insert(result, value); + end + end + end + + return result; +end + function get_modules(host) return modulemap[host]; end diff --git a/core/portmanager.lua b/core/portmanager.lua index b02ba53b..7a247452 100644 --- a/core/portmanager.lua +++ b/core/portmanager.lua @@ -1,6 +1,7 @@ local config = require "core.configmanager"; local certmanager = require "core.certmanager"; local server = require "net.server"; +local socket = require "socket"; local log = require "util.logger".init("portmanager"); local multitable = require "util.multitable"; @@ -8,7 +9,7 @@ local set = require "util.set"; local table = table; local setmetatable, rawset, rawget = setmetatable, rawset, rawget; -local type, tonumber, ipairs = type, tonumber, ipairs; +local type, tonumber, tostring, ipairs, pairs = type, tonumber, tostring, ipairs, pairs; local prosody = prosody; local fire_event = prosody.events.fire_event; @@ -17,9 +18,13 @@ module "portmanager"; --- Config -local default_interfaces = { "*" }; -local default_local_interfaces = { "127.0.0.1" }; -if config.get("*", "use_ipv6") then +local default_interfaces = { }; +local default_local_interfaces = { }; +if config.get("*", "use_ipv4") ~= false then + table.insert(default_interfaces, "*"); + table.insert(default_local_interfaces, "127.0.0.1"); +end +if socket.tcp6 and config.get("*", "use_ipv6") ~= false then table.insert(default_interfaces, "::"); table.insert(default_local_interfaces, "::1"); end @@ -65,6 +70,16 @@ prosody.events.add_handler("item-removed/net-provider", function (event) unregister_service(item.name, item); end); +local function duplicate_ssl_config(ssl_config) + local ssl_config = type(ssl_config) == "table" and ssl_config or {}; + + local _config = {}; + for k, v in pairs(ssl_config) do + _config[k] = v; + end + return _config; +end + --- Public API function activate(service_name) @@ -97,31 +112,50 @@ function activate(service_name) bind_ports = set.new(type(bind_ports) ~= "table" and { bind_ports } or bind_ports ); local mode, ssl = listener.default_mode or "*a"; + local hooked_ports = {}; for interface in bind_interfaces do for port in bind_ports do - port = tonumber(port); - if #active_services:search(nil, interface, port) > 0 then + local port_number = tonumber(port); + if not port_number then + log("error", "Invalid port number specified for service '%s': %s", service_info.name, tostring(port)); + elseif #active_services:search(nil, interface, port_number) > 0 then log("error", "Multiple services configured to listen on the same port ([%s]:%d): %s, %s", interface, port, active_services:search(nil, interface, port)[1][1].service.name or "<unnamed>", service_name or "<unnamed>"); else local err; -- Create SSL context for this service/port if service_info.encryption == "ssl" then - local ssl_config = config.get("*", config_prefix.."ssl"); - ssl, err = certmanager.create_context(service_info.name.." port "..port, "server", ssl_config and (ssl_config[port] - or (ssl_config.certificate and ssl_config))); + local ssl_config = duplicate_ssl_config((config.get("*", config_prefix.."ssl") and config.get("*", config_prefix.."ssl")[interface]) + or (config.get("*", config_prefix.."ssl") and config.get("*", config_prefix.."ssl")[port]) + or config.get("*", config_prefix.."ssl") + or (config.get("*", "ssl") and config.get("*", "ssl")[interface]) + or (config.get("*", "ssl") and config.get("*", "ssl")[port]) + or config.get("*", "ssl")); + -- add default entries for, or override ssl configuration + if ssl_config and service_info.ssl_config then + for key, value in pairs(service_info.ssl_config) do + if not service_info.ssl_config_override and not ssl_config[key] then + ssl_config[key] = value; + elseif service_info.ssl_config_override then + ssl_config[key] = value; + end + end + end + + ssl, err = certmanager.create_context(service_info.name.." port "..port, "server", ssl_config); if not ssl then - log("error", "Error binding encrypted port for %s: %s", service_info.name, error_to_friendly_message(service_name, port, err) or "unknown error"); + log("error", "Error binding encrypted port for %s: %s", service_info.name, error_to_friendly_message(service_name, port_number, err) or "unknown error"); end end if not err then -- Start listening on interface+port - local handler, err = server.addserver(interface, port, listener, mode, ssl); + local handler, err = server.addserver(interface, port_number, listener, mode, ssl); if not handler then - log("error", "Failed to open server port %d on %s, %s", port, interface, error_to_friendly_message(service_name, port, err)); + log("error", "Failed to open server port %d on %s, %s", port_number, interface, error_to_friendly_message(service_name, port_number, err)); else - log("debug", "Added listening service %s to [%s]:%d", service_name, interface, port); - active_services:add(service_name, interface, port, { + table.insert(hooked_ports, "["..interface.."]:"..port_number); + log("debug", "Added listening service %s to [%s]:%d", service_name, interface, port_number); + active_services:add(service_name, interface, port_number, { server = handler; service = service_info; }); @@ -130,7 +164,7 @@ function activate(service_name) end end end - log("info", "Activated service '%s'", service_name); + log("info", "Activated service '%s' on %s", service_name, #hooked_ports == 0 and "no ports" or table.concat(hooked_ports, ", ")); return true; end diff --git a/core/rostermanager.lua b/core/rostermanager.lua index fdb890f9..5e06e3f7 100644 --- a/core/rostermanager.lua +++ b/core/rostermanager.lua @@ -11,16 +11,14 @@ local log = require "util.logger".init("rostermanager"); -local setmetatable = setmetatable; -local format = string.format; -local pcall = pcall; -local pairs, ipairs = pairs, ipairs; +local pairs = pairs; local tostring = tostring; local hosts = hosts; local bare_sessions = bare_sessions; local datamanager = require "util.datamanager" +local um_user_exists = require "core.usermanager".user_exists; local st = require "util.stanza"; module "rostermanager" @@ -108,6 +106,11 @@ function load_roster(username, host) end function save_roster(username, host, roster) + if not um_user_exists(username, host) then + log("debug", "not saving roster for %s@%s: the user doesn't exist", username, host); + return nil; + end + log("debug", "save_roster: saving roster for %s@%s", username, host); if not roster then roster = hosts[host] and hosts[host].sessions[username] and hosts[host].sessions[username].roster; diff --git a/core/s2smanager.lua b/core/s2smanager.lua index 6049e12e..06d3f2c9 100644 --- a/core/s2smanager.lua +++ b/core/s2smanager.lua @@ -8,39 +8,30 @@ -local hosts = hosts; -local tostring, pairs, ipairs, getmetatable, newproxy, setmetatable - = tostring, pairs, ipairs, getmetatable, newproxy, setmetatable; +local hosts = prosody.hosts; +local tostring, pairs, setmetatable + = tostring, pairs, setmetatable; -local fire_event = prosody.events.fire_event; local logger_init = require "util.logger".init; local log = logger_init("s2smanager"); -local config = require "core.configmanager"; - local prosody = _G.prosody; incoming_s2s = {}; prosody.incoming_s2s = incoming_s2s; local incoming_s2s = incoming_s2s; +local fire_event = prosody.events.fire_event; module "s2smanager" -local open_sessions = 0; - function new_incoming(conn) local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming", hosts = {} }; - if true then - session.trace = newproxy(true); - getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; end; - end - open_sessions = open_sessions + 1; session.log = logger_init("s2sin"..tostring(session):match("[a-f0-9]+$")); incoming_s2s[session] = true; return session; end -function new_outgoing(from_host, to_host, connect) +function new_outgoing(from_host, to_host) local host_session = { to_host = to_host, from_host = from_host, host = from_host, notopen = true, type = "s2sout_unauthed", direction = "outgoing" }; hosts[from_host].s2sout[to_host] = host_session; @@ -49,75 +40,6 @@ function new_outgoing(from_host, to_host, connect) return host_session; end -function make_authenticated(session, host) - if not session.secure then - local local_host = session.direction == "incoming" and session.to_host or session.from_host; - if config.get(local_host, "core", "s2s_require_encryption") then - session:close({ - condition = "policy-violation", - text = "Encrypted server-to-server communication is required but was not " - ..((session.direction == "outgoing" and "offered") or "used") - }); - end - end - if session.type == "s2sout_unauthed" then - session.type = "s2sout"; - elseif session.type == "s2sin_unauthed" then - session.type = "s2sin"; - if host then - if not session.hosts[host] then session.hosts[host] = {}; end - session.hosts[host].authed = true; - end - elseif session.type == "s2sin" and host then - if not session.hosts[host] then session.hosts[host] = {}; end - session.hosts[host].authed = true; - else - return false; - end - session.log("debug", "connection %s->%s is now authenticated for %s", session.from_host, session.to_host, host); - - mark_connected(session); - - return true; -end - --- Stream is authorised, and ready for normal stanzas -function mark_connected(session) - local sendq, send = session.sendq, session.sends2s; - - local from, to = session.from_host, session.to_host; - - session.log("info", "%s s2s connection %s->%s complete", session.direction, from, to); - - local event_data = { session = session }; - if session.type == "s2sout" then - prosody.events.fire_event("s2sout-established", event_data); - hosts[from].events.fire_event("s2sout-established", event_data); - else - local host_session = hosts[to]; - session.send = function(stanza) - return host_session.events.fire_event("route/remote", { from_host = to, to_host = from, stanza = stanza }); - end; - - prosody.events.fire_event("s2sin-established", event_data); - hosts[to].events.fire_event("s2sin-established", event_data); - end - - if session.direction == "outgoing" then - if sendq then - session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host); - for i, data in ipairs(sendq) do - send(data[1]); - sendq[i] = nil; - end - session.sendq = nil; - end - - session.ip_hosts = nil; - session.srv_hosts = nil; - end -end - local resting_session = { -- Resting, not dead destroyed = true; type = "s2s_destroyed"; @@ -133,7 +55,7 @@ local resting_session = { -- Resting, not dead function retire_session(session, reason) local log = session.log or log; for k in pairs(session) do - if k ~= "trace" and k ~= "log" and k ~= "id" and k ~= "conn" then + if k ~= "log" and k ~= "id" and k ~= "conn" then session[k] = nil; end end @@ -158,12 +80,12 @@ function destroy_session(session, reason) local event_data = { session = session, reason = reason }; if session.type == "s2sout" then - prosody.events.fire_event("s2sout-destroyed", event_data); + fire_event("s2sout-destroyed", event_data); if hosts[session.from_host] then hosts[session.from_host].events.fire_event("s2sout-destroyed", event_data); end elseif session.type == "s2sin" then - prosody.events.fire_event("s2sin-destroyed", event_data); + fire_event("s2sin-destroyed", event_data); if hosts[session.to_host] then hosts[session.to_host].events.fire_event("s2sin-destroyed", event_data); end diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua index 05b2d64b..98ead07f 100644 --- a/core/sessionmanager.lua +++ b/core/sessionmanager.lua @@ -24,22 +24,10 @@ local uuid_generate = require "util.uuid".generate; local initialize_filters = require "util.filters".initialize; local gettime = require "socket".gettime; -local newproxy = newproxy; -local getmetatable = getmetatable; - module "sessionmanager" -local open_sessions = 0; - function new_session(conn) local session = { conn = conn, type = "c2s_unauthed", conntime = gettime() }; - if true then - session.trace = newproxy(true); - getmetatable(session.trace).__gc = function () open_sessions = open_sessions - 1; end; - end - open_sessions = open_sessions + 1; - log("debug", "open sessions now: %d", open_sessions); - local filter = initialize_filters(session); local w = conn.write; session.send = function (t) @@ -72,7 +60,7 @@ local resting_session = { -- Resting, not dead function retire_session(session) local log = session.log or log; for k in pairs(session) do - if k ~= "trace" and k ~= "log" and k ~= "id" then + if k ~= "log" and k ~= "id" then session[k] = nil; end end @@ -140,7 +128,7 @@ function bind_resource(session, resource) local sessions = hosts[session.host].sessions[session.username].sessions; if sessions[resource] then -- Resource conflict - local policy = config_get(session.host, "core", "conflict_resolve"); + local policy = config_get(session.host, "conflict_resolve"); local increment; if policy == "random" then resource = uuid_generate(); diff --git a/core/storagemanager.lua b/core/storagemanager.lua index 36a671be..1c82af6d 100644 --- a/core/storagemanager.lua +++ b/core/storagemanager.lua @@ -86,7 +86,7 @@ function open(host, store, typ) if not ret then if err == "unsupported-store" then log("debug", "Storage driver %s does not support store %s (%s), falling back to null driver", - driver_name, store, typ); + driver_name, store, typ or "<nil>"); ret = null_storage_driver; err = nil; end diff --git a/core/usermanager.lua b/core/usermanager.lua index 417d7037..08343bee 100644 --- a/core/usermanager.lua +++ b/core/usermanager.lua @@ -42,8 +42,8 @@ function initialize_host(host) host_session.events.add_handler("item-added/auth-provider", function (event) local provider = event.item; - local auth_provider = config.get(host, "core", "authentication") or default_provider; - if config.get(host, "core", "anonymous_login") then + local auth_provider = config.get(host, "authentication") or default_provider; + if config.get(host, "anonymous_login") then log("error", "Deprecated config option 'anonymous_login'. Use authentication = 'anonymous' instead."); auth_provider = "anonymous"; end -- COMPAT 0.7 @@ -61,8 +61,8 @@ function initialize_host(host) end end); host_session.users = new_null_provider(); -- Start with the default usermanager provider - local auth_provider = config.get(host, "core", "authentication") or default_provider; - if config.get(host, "core", "anonymous_login") then auth_provider = "anonymous"; end -- COMPAT 0.7 + local auth_provider = config.get(host, "authentication") or default_provider; + if config.get(host, "anonymous_login") then auth_provider = "anonymous"; end -- COMPAT 0.7 if auth_provider ~= "null" then modulemanager.load(host, "auth_"..auth_provider); end @@ -116,8 +116,8 @@ function is_admin(jid, host) jid = jid_bare(jid); host = host or "*"; - local host_admins = config.get(host, "core", "admins"); - local global_admins = config.get("*", "core", "admins"); + local host_admins = config.get(host, "admins"); + local global_admins = config.get("*", "admins"); if host_admins and host_admins ~= global_admins then if type(host_admins) == "table" then diff --git a/net/dns.lua b/net/dns.lua index a134eceb..c9c51fe8 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -223,7 +223,7 @@ end function dns.random(...) -- - - - - - - - - - - - - - - - - - - dns.random - math.randomseed(math.floor(10000*socket.gettime())); + math.randomseed(math.floor(10000*socket.gettime()) % 0x100000000); dns.random = math.random; return dns.random(...); end diff --git a/net/http.lua b/net/http.lua index 273eee09..3b783a41 100644 --- a/net/http.lua +++ b/net/http.lua @@ -9,14 +9,17 @@ local socket = require "socket" local b64 = require "util.encodings".base64.encode; local url = require "socket.url" -local httpstream_new = require "util.httpstream".new; +local httpstream_new = require "net.http.parser".new; +local util_http = require "util.http"; + +local ssl_available = pcall(require, "ssl"); local server = require "net.server" local t_insert, t_concat = table.insert, table.concat; -local pairs, ipairs = pairs, ipairs; -local tonumber, tostring, xpcall, select, debug_traceback, char, format = - tonumber, tostring, xpcall, select, debug.traceback, string.char, string.format; +local pairs = pairs; +local tonumber, tostring, xpcall, select, traceback = + tonumber, tostring, xpcall, select, debug.traceback; local log = require "util.logger".init("http"); @@ -63,64 +66,29 @@ end function listener.ondisconnect(conn, err) local request = requests[conn]; if request and request.conn then - request:reader(nil); + request:reader(nil, err); end requests[conn] = nil; end -function urlencode(s) return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); end -function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); end - -local function _formencodepart(s) - return s and (s:gsub("%W", function (c) - if c ~= " " then - return format("%%%02x", c:byte()); - else - return "+"; - end - end)); -end - -function formencode(form) - local result = {}; - if form[1] then -- Array of ordered { name, value } - for _, field in ipairs(form) do - t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); - end - else -- Unordered map of name -> value - for name, value in pairs(form) do - t_insert(result, _formencodepart(name).."=".._formencodepart(value)); - end - end - return t_concat(result, "&"); -end - -function formdecode(s) - if not s:match("=") then return urldecode(s); end - local r = {}; - for k, v in s:gmatch("([^=&]*)=([^&]*)") do - k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20"); - k, v = urldecode(k), urldecode(v); - t_insert(r, { name = k, value = v }); - r[k] = v; - end - return r; -end - -local function request_reader(request, data, startpos) +local function request_reader(request, data, err) if not request.parser then - if not data then return; end - local function success_cb(r) + local function error_cb(reason) if request.callback then - for k,v in pairs(r) do request[k] = v; end - request.callback(r.body, r.code, request, r); + request.callback(reason or "connection-closed", 0, request); request.callback = nil; end destroy_request(request); end - local function error_cb(r) + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) if request.callback then - request.callback(r or "connection-closed", 0, request); + request.callback(r.body, r.code, r, request); request.callback = nil; end destroy_request(request); @@ -133,7 +101,7 @@ local function request_reader(request, data, startpos) request.parser:feed(data); end -local function handleerr(err) log("error", "Traceback[http]: %s: %s", tostring(err), debug_traceback()); end +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end function request(u, ex, callback) local req = url.parse(u); @@ -177,6 +145,9 @@ function request(u, ex, callback) req.method, req.headers, req.body = method, headers, body; local using_https = req.scheme == "https"; + if using_https and not ssl_available then + error("SSL not available, unable to contact https URL"); + end local port = tonumber(req.port) or (using_https and 443 or 80); -- Connect the socket, and wrap it with net.server @@ -188,7 +159,12 @@ function request(u, ex, callback) return nil, err; end - req.handler, req.conn = server.wrapclient(conn, req.host, port, listener, "*a", using_https and { mode = "client", protocol = "sslv23" }); + local sslctx = false; + if using_https then + sslctx = ex and ex.sslctx or { mode = "client", protocol = "sslv23", options = { "no_sslv2" } }; + end + + req.handler, req.conn = server.wrapclient(conn, req.host, port, listener, "*a", sslctx); req.write = function (...) return req.handler:write(...); end req.callback = function (content, code, request, response) log("debug", "Calling callback, status %s", code or "---"); return select(2, xpcall(function () return callback(content, code, request, response) end, handleerr)); end @@ -206,6 +182,10 @@ function destroy_request(request) end end -_M.urlencode = urlencode; +local urlencode, urldecode = util_http.urlencode, util_http.urldecode; +local formencode, formdecode = util_http.formencode, util_http.formdecode; + +_M.urlencode, _M.urldecode = urlencode, urldecode; +_M.formencode, _M.formdecode = formencode, formdecode; return _M; diff --git a/net/http/parser.lua b/net/http/parser.lua index 2545b5ac..f9e6cea0 100644 --- a/net/http/parser.lua +++ b/net/http/parser.lua @@ -1,8 +1,7 @@ - local tonumber = tonumber; local assert = assert; local url_parse = require "socket.url".parse; -local urldecode = require "net.http".urldecode; +local urldecode = require "util.http".urldecode; local function preprocess_path(path) path = urldecode((path:gsub("//+", "/"))); @@ -29,7 +28,7 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) local client = true; if not parser_type or parser_type == "server" then client = false; else assert(parser_type == "client", "Invalid parser type"); end local buf = ""; - local chunked; + local chunked, chunk_size, chunk_start; local state = nil; local packet; local len; @@ -65,12 +64,12 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) first_line = line; if client then httpversion, status_code, reason_phrase = line:match("^HTTP/(1%.[01]) (%d%d%d) (.*)$"); + status_code = tonumber(status_code); if not status_code then error = true; return error_cb("invalid-status-line"); end have_body = not ( (options_cb and options_cb().method == "HEAD") or (status_code == 204 or status_code == 304 or status_code == 301) or (status_code >= 100 and status_code < 200) ); - chunked = have_body and headers["transfer-encoding"] == "chunked"; else method, path, httpversion = line:match("^(%w+) (%S+) HTTP/(1%.[01])$"); if not method then error = true; return error_cb("invalid-status-line"); end @@ -78,6 +77,7 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) end end if not first_line then error = true; return error_cb("invalid-status-line"); end + chunked = have_body and headers["transfer-encoding"] == "chunked"; len = tonumber(headers["content-length"]); -- TODO check for invalid len if client then -- FIXME handle '100 Continue' response (by skipping it) @@ -120,22 +120,30 @@ function httpstream.new(success_cb, error_cb, parser_type, options_cb) if state then -- read body if client then if chunked then - local index = buf:find("\r\n", nil, true); - if not index then return; end -- not enough data - local chunk_size = buf:match("^%x+"); - if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end - chunk_size = tonumber(chunk_size, 16); - index = index + 2; - if chunk_size == 0 then - state = nil; success_cb(packet); - elseif #buf - index + 1 >= chunk_size then -- we have a chunk - packet.body = packet.body..buf:sub(index, index + chunk_size - 1); - buf = buf:sub(index + chunk_size); + if not buf:find("\r\n", nil, true) then + return; + end -- not enough data + if not chunk_size then + chunk_size, chunk_start = buf:match("^(%x+)[^\r\n]*\r\n()"); + chunk_size = chunk_size and tonumber(chunk_size, 16); + if not chunk_size then error = true; return error_cb("invalid-chunk-size"); end + end + if chunk_size == 0 and buf:find("\r\n\r\n", chunk_start-2, true) then + state, chunk_size = nil, nil; + buf = buf:gsub("^.-\r\n\r\n", ""); -- This ensure extensions and trailers are stripped + success_cb(packet); + elseif #buf - chunk_start + 2 >= chunk_size then -- we have a chunk + packet.body = packet.body..buf:sub(chunk_start, chunk_start + (chunk_size-1)); + buf = buf:sub(chunk_start + chunk_size + 2); + chunk_size, chunk_start = nil, nil; + else -- Partial chunk remaining + break; end - error("trailers"); -- FIXME MUST read trailers elseif len and #buf >= len then packet.body, buf = buf:sub(1, len), buf:sub(len + 1); state = nil; success_cb(packet); + else + break; end elseif #buf >= len then packet.body, buf = buf:sub(1, len), buf:sub(len + 1); diff --git a/net/http/server.lua b/net/http/server.lua index 7cf25009..dec7da19 100644 --- a/net/http/server.lua +++ b/net/http/server.lua @@ -9,7 +9,7 @@ local pairs = pairs; local s_upper = string.upper; local setmetatable = setmetatable; local xpcall = xpcall; -local debug = debug; +local traceback = debug.traceback; local tostring = tostring; local codes = require "net.http.codes"; @@ -27,8 +27,11 @@ local function is_wildcard_match(wildcard_event, event) return wildcard_event:sub(1, -2) == event:sub(1, #wildcard_event-1); end +local recent_wildcard_events, max_cached_wildcard_events = {}, 10000; + local event_map = events._event_map; setmetatable(events._handlers, { + -- Called when firing an event that doesn't exist (but may match a wildcard handler) __index = function (handlers, curr_event) if is_wildcard_event(curr_event) then return; end -- Wildcard events cannot be fired -- Find all handlers that could match this event, sort them @@ -58,6 +61,12 @@ setmetatable(events._handlers, { handlers_array = false; end rawset(handlers, curr_event, handlers_array); + if not event_map[curr_event] then -- Only wildcard handlers match, if any + table.insert(recent_wildcard_events, curr_event); + if #recent_wildcard_events > max_cached_wildcard_events then + rawset(handlers, table.remove(recent_wildcard_events, 1), nil); + end + end return handlers_array; end; __newindex = function (handlers, curr_event, handlers_array) @@ -79,7 +88,7 @@ local _1, _2, _3; local function _handle_request() return handle_request(_1, _2, _3); end local last_err; -local function _traceback_handler(err) last_err = err; log("error", "Traceback[http]: %s: %s", tostring(err), debug.traceback()); end +local function _traceback_handler(err) last_err = err; log("error", "Traceback[httpserver]: %s", traceback(tostring(err), 2)); end events.add_handler("http-error", function (error) return "Error processing request: "..codes[error.code]..". Check your error log for more information."; end, -1); @@ -89,29 +98,30 @@ function listener.onconnect(conn) local pending = {}; local waiting = false; local function process_next() - --if waiting then log("debug", "can't process_next, waiting"); return; end - if sessions[conn] and #pending > 0 then + if waiting then log("debug", "can't process_next, waiting"); return; end + waiting = true; + while sessions[conn] and #pending > 0 do local request = t_remove(pending); --log("debug", "process_next: %s", request.path); - waiting = true; --handle_request(conn, request, process_next); _1, _2, _3 = conn, request, process_next; if not xpcall(_handle_request, _traceback_handler) then conn:write("HTTP/1.0 500 Internal Server Error\r\n\r\n"..events.fire_event("http-error", { code = 500, private_message = last_err })); conn:close(); end - else - --log("debug", "ready for more"); - waiting = false; end + --log("debug", "ready for more"); + waiting = false; end local function success_cb(request) --log("debug", "success_cb: %s", request.path); + if waiting then + log("error", "http connection handler is not reentrant: %s", request.path); + assert(false, "http connection handler is not reentrant"); + end request.secure = secure; t_insert(pending, request); - if not waiting then - process_next(); - end + process_next(); end local function error_cb(err) log("debug", "error_cb: %s", err or "<nil>"); @@ -158,7 +168,7 @@ function handle_request(conn, request, finish_cb) local conn_header = request.headers.connection; conn_header = conn_header and ","..conn_header:gsub("[ \t]", ""):lower().."," or "" local httpversion = request.httpversion - local persistent = conn_header:find(",keep-alive,", 1, true) + local persistent = conn_header:find(",Keep-Alive,", 1, true) or (httpversion == "1.1" and not conn_header:find(",close,", 1, true)); local response_conn_header; @@ -218,7 +228,13 @@ function handle_request(conn, request, finish_cb) body = result; elseif result_type == "table" then for k, v in pairs(result) do - response[k] = v; + if k ~= "headers" then + response[k] = v; + else + for header_name, header_value in pairs(v) do + response.headers[header_name] = header_value; + end + end end end response:send(body); @@ -277,6 +293,9 @@ end function _M.set_default_host(host) default_host = host; end +function _M.fire_event(event, ...) + return events.fire_event(event, ...); +end _M.listener = listener; _M.codes = codes; diff --git a/net/server.lua b/net/server.lua index 3cdbe551..375e7081 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,7 +6,7 @@ -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "core", "use_libevent"); +local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); if use_luaevent then use_luaevent = pcall(require, "luaevent.core"); @@ -42,11 +42,16 @@ end if prosody then local config_get = require "core.configmanager".get; + local defaults = {}; + for k,v in pairs(server.cfg or server.getsettings()) do + defaults[k] = v; + end local function load_config() - local settings = config_get("*", "core", "network_settings") or {}; + local settings = config_get("*", "network_settings") or {}; if use_luaevent then local event_settings = { ACCEPT_DELAY = settings.event_accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; CLEAR_DELAY = settings.event_clear_interval; CONNECT_TIMEOUT = settings.connect_timeout; DEBUG = settings.debug; @@ -59,11 +64,15 @@ if prosody then WRITE_TIMEOUT = settings.send_timeout; }; - for k, v in pairs(event_settings) do - server.cfg[k] = v; + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; end else - server.changesettings(settings); + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); end end load_config(); diff --git a/net/server_event.lua b/net/server_event.lua index 08926939..5eae95a9 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -23,6 +23,7 @@ local cfg = { HANDSHAKE_TIMEOUT = 60, -- timeout in seconds per handshake attempt MAX_READ_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes allowed to read from sockets MAX_SEND_LENGTH = 1024 * 1024 * 1024 * 1024, -- max bytes size of write buffer (for writing on sockets) + ACCEPT_QUEUE = 128, -- might influence the length of the pending sockets queue ACCEPT_DELAY = 10, -- seconds to wait until the next attempt of a full server to accept READ_TIMEOUT = 60 * 60 * 6, -- timeout in seconds for read data from socket WRITE_TIMEOUT = 180, -- timeout in seconds for write data on socket @@ -460,7 +461,6 @@ end local handleclient; do local string_sub = string.sub -- caching table lookups - local string_len = string.len local addevent = base.addevent local socket_gettime = socket.gettime function handleclient( client, ip, port, server, pattern, listener, sslctx ) -- creates an client interface diff --git a/net/server_select.lua b/net/server_select.lua index 0852d444..7eb330a8 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -10,11 +10,6 @@ local use = function( what ) return _G[ what ] end -local clean = function( tbl ) - for i, k in pairs( tbl ) do - tbl[ i ] = nil - end -end local log, table_concat = require ("util.logger").init("socket"), table.concat; local out_put = function (...) return log("debug", table_concat{...}); end @@ -47,7 +42,6 @@ local os_difftime = os.difftime local math_min = math.min local math_huge = math.huge local table_concat = table.concat -local string_len = string.len local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -107,6 +101,7 @@ local _readtraffic local _selecttimeout local _sleeptime +local _tcpbacklog local _starttime local _currenttime @@ -118,11 +113,10 @@ local _checkinterval local _sendtimeout local _readtimeout -local _cleanqueue - local _timer -local _maxclientsperserver +local _maxselectlen +local _maxfd local _maxsslhandshake @@ -146,6 +140,7 @@ _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select _sleeptime = 0 -- time to wait at the end of every loop +_tcpbacklog = 128 -- some kind of hint to the OS _maxsendlen = 51000 * 1024 -- max len of send buffer _maxreadlen = 25000 * 1024 -- max len of read buffer @@ -154,17 +149,21 @@ _checkinterval = 1200000 -- interval in secs to check idle clients _sendtimeout = 60000 -- allowed send idle time in secs _readtimeout = 6 * 60 * 60 -- allowed read idle time in secs -_cleanqueue = false -- clean bufferqueue after using - -_maxclientsperserver = 1000 +local is_windows = package.config:sub(1,1) == "\\" -- check the directory separator, to detemine whether this is Windows +_maxfd = luasocket._SETSIZE or (is_windows and math.huge) or 1024 -- max fd number, limit to 1024 by default to prevent glibc buffer overflow, but not on Windows +_maxselectlen = luasocket._SETSIZE or 1024 -- But this still applies on Windows _maxsslhandshake = 30 -- max handshake round-trips ----------------------------------// PRIVATE //-- -wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxconnections ) -- this function wraps a server +wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx ) -- this function wraps a server -- FIXME Make sure FD < _maxfd - maxconnections = maxconnections or _maxclientsperserver + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) + socket:close() + return nil, "fd-too-large" + end local connections = 0 @@ -201,20 +200,23 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco --mem_free( ) out_put "server.lua: closed server handler and removed sockets from list" end - handler.pause = function() + handler.pause = function( hard ) if not handler.paused then - socket:close( ) - _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) _readlistlen = removesocket( _readlist, socket, _readlistlen ) - _socketlist[ socket ] = nil - socket = nil; + if hard then + _socketlist[ socket ] = nil + socket:close( ) + socket = nil; + end handler.paused = true; end end - handler.resume = function() + handler.resume = function( ) if handler.paused then - socket = socket_bind( ip, serverport ); - socket:settimeout( 0 ) + if not socket then + socket = socket_bind( ip, serverport, _tcpbacklog ); + socket:settimeout( 0 ) + end _readlistlen = addsocket(_readlist, socket, _readlistlen) _socketlist[ socket ] = handler handler.paused = false; @@ -230,7 +232,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco return socket end handler.readbuffer = function( ) - if connections > maxconnections then + if _readlistlen >= _maxselectlen or _sendlistlen >= _maxselectlen then handler.pause( ) out_put( "server.lua: refused new client connection: server full" ) return false @@ -244,7 +246,7 @@ wrapserver = function( listeners, socket, ip, serverport, pattern, sslctx, maxco end connections = connections + 1 out_put( "server.lua: accepted new client connection from ", tostring(ip), ":", tostring(clientport), " to ", tostring(serverport)) - if dispatch then + if dispatch and not sslctx then -- SSL connections will notify onconnect when handshake completes return dispatch( handler ); end return; @@ -258,6 +260,12 @@ end wrapconnection = function( server, listeners, socket, ip, serverport, clientport, pattern, sslctx ) -- this function wraps a client to a handler object + if socket:getfd() >= _maxfd then + out_error("server.lua: Disallowed FD number: "..socket:getfd()) -- PROTIP: Switch to libevent + socket:close( ) -- Should we send some kind of error here? + server.pause( ) + return nil, nil, "fd-too-large" + end socket:settimeout( 0 ) --// local import of socket methods //-- @@ -335,9 +343,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.force_close = function ( self, err ) if bufferqueuelen ~= 0 then out_put("server.lua: discarding unwritten data for ", tostring(ip), ":", tostring(clientport)) - for i = bufferqueuelen, 1, -1 do - bufferqueue[i] = nil; - end bufferqueuelen = 0; end return self:close(err); @@ -391,7 +396,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return clientport end local write = function( self, data ) - bufferlen = bufferlen + string_len( data ) + bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle handler.write = idfalse -- dont write anymore @@ -473,7 +478,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local buffer, err, part = receive( socket, pattern ) -- receive buffer with "pattern" if not err or (err == "wantread" or err == "timeout") then -- received something local buffer = buffer or part or "" - local len = string_len( buffer ) + local len = #buffer if len > maxreadlen then handler:close( "receive buffer exceeded" ) return false @@ -499,7 +504,9 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport count = ( succ or byte or 0 ) * STAT_UNIT sendtraffic = sendtraffic + count _sendtraffic = _sendtraffic + count - _ = _cleanqueue and clean( bufferqueue ) + for i = bufferqueuelen,1,-1 do + bufferqueue[ i ] = nil + end --out_put( "server.lua: sended '", buffer, "', bytes: ", tostring(succ), ", error: ", tostring(err), ", part: ", tostring(byte), ", to: ", tostring(ip), ":", tostring(clientport) ) else succ, err, count = false, "unexpected close", 0; @@ -568,7 +575,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport end out_put( "server.lua: ssl handshake error: ", tostring(err or "handshake too long") ) _ = handler and handler:force_close("ssl handshake failed") - return false, err -- handshake failed + return false, err -- handshake failed end ) end @@ -612,7 +619,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport handler.readbuffer = handshake handler.sendbuffer = handshake - return handshake( socket ) -- do handshake + return handshake( socket ) -- do handshake end end @@ -628,10 +635,10 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport if sslctx and luasec then out_put "server.lua: auto-starting ssl negotiation..." handler.autostart_ssl = true; - local ok, err = handler:starttls(sslctx); - if ok == false then - return nil, nil, err - end + local ok, err = handler:starttls(sslctx); + if ok == false then + return nil, nil, err + end end return handler, socket @@ -716,12 +723,12 @@ addserver = function( addr, port, listeners, pattern, sslctx ) -- this function return nil, err end addr = addr or "*" - local server, err = socket_bind( addr, port ) + local server, err = socket_bind( addr, port, _tcpbacklog ) if err then out_error( "server.lua, [", addr, "]:", port, ": ", err ) return nil, err end - local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx, _maxclientsperserver ) -- wrap new server socket + local handler, err = wrapserver( listeners, server, addr, port, pattern, sslctx ) -- wrap new server socket if not handler then server:close( ) return nil, err @@ -765,7 +772,19 @@ closeall = function( ) end getsettings = function( ) - return _selecttimeout, _sleeptime, _maxsendlen, _maxreadlen, _checkinterval, _sendtimeout, _readtimeout, _cleanqueue, _maxclientsperserver, _maxsslhandshake + return { + select_timeout = _selecttimeout; + select_sleep_time = _sleeptime; + tcp_backlog = _tcpbacklog; + max_send_buffer_size = _maxsendlen; + max_receive_buffer_size = _maxreadlen; + select_idle_check_interval = _checkinterval; + send_timeout = _sendtimeout; + read_timeout = _readtimeout; + max_connections = _maxselectlen; + max_ssl_handshake_roundtrips = _maxsslhandshake; + highest_allowed_fd = _maxfd; + } end changesettings = function( new ) @@ -777,11 +796,12 @@ changesettings = function( new ) _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval + _tcpbacklog = tonumber( new.tcp_backlog ) or _tcpbacklog _sendtimeout = tonumber( new.send_timeout ) or _sendtimeout _readtimeout = tonumber( new.read_timeout ) or _readtimeout - _cleanqueue = new.select_clean_queue - _maxclientsperserver = new.max_connections or _maxclientsperserver + _maxselectlen = new.max_connections or _maxselectlen _maxsslhandshake = new.max_ssl_handshake_roundtrips or _maxsslhandshake + _maxfd = new.highest_allowed_fd or _maxfd return true end @@ -831,9 +851,31 @@ loop = function(once) -- this is the main loop of the program for handler, err in pairs( _closelist ) do handler.disconnect( )( handler, err ) handler:force_close() -- forced disconnect + _closelist[ handler ] = nil; end - clean( _closelist ) _currenttime = luasocket_gettime( ) + + -- Check for socket timeouts + local difftime = os_difftime( _currenttime - _starttime ) + if difftime > _checkinterval then + _starttime = _currenttime + for handler, timestamp in pairs( _writetimes ) do + if os_difftime( _currenttime - timestamp ) > _sendtimeout then + --_writetimes[ handler ] = nil + handler.disconnect( )( handler, "send timeout" ) + handler:force_close() -- forced disconnect + end + end + for handler, timestamp in pairs( _readtimes ) do + if os_difftime( _currenttime - timestamp ) > _readtimeout then + --_readtimes[ handler ] = nil + handler.disconnect( )( handler, "read timeout" ) + handler:close( ) -- forced disconnect? + end + end + end + + -- Fire timers if _currenttime - _timer >= math_min(next_timer_time, 1) then next_timer_time = math_huge; for i = 1, _timerlistlen do @@ -844,8 +886,9 @@ loop = function(once) -- this is the main loop of the program else next_timer_time = next_timer_time - (_currenttime - _timer); end - socket_sleep( _sleeptime ) -- wait some time - --collectgarbage( ) + + -- wait some time (0 by default) + socket_sleep( _sleeptime ) until quitting; if once and quitting == "once" then quitting = nil; return; end return "quitting" @@ -862,7 +905,8 @@ end --// EXPERIMENTAL //-- local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx ) - local handler = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + local handler, socket, err = wrapconnection( nil, listeners, socket, ip, serverport, "clientport", pattern, sslctx ) + if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) @@ -908,28 +952,6 @@ use "setmetatable" ( _writetimes, { __mode = "k" } ) _timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) -addtimer( function( ) - local difftime = os_difftime( _currenttime - _starttime ) - if difftime > _checkinterval then - _starttime = _currenttime - for handler, timestamp in pairs( _writetimes ) do - if os_difftime( _currenttime - timestamp ) > _sendtimeout then - --_writetimes[ handler ] = nil - handler.disconnect( )( handler, "send timeout" ) - handler:force_close() -- forced disconnect - end - end - for handler, timestamp in pairs( _readtimes ) do - if os_difftime( _currenttime - timestamp ) > _readtimeout then - --_readtimes[ handler ] = nil - handler.disconnect( )( handler, "read timeout" ) - handler:close( ) -- forced disconnect? - end - end - end - end -) - local function setlogger(new_logger) local old_logger = log; if new_logger then diff --git a/plugins/mod_admin_adhoc.lua b/plugins/mod_admin_adhoc.lua index f136eb46..31c4bde4 100644 --- a/plugins/mod_admin_adhoc.lua +++ b/plugins/mod_admin_adhoc.lua @@ -10,8 +10,9 @@ local prosody = _G.prosody; local hosts = prosody.hosts; local t_concat = table.concat; -local iterators = require "util.iterators"; -local keys, values = iterators.keys, iterators.values; +local module_host = module:get_host(); + +local keys = require "util.iterators".keys; local usermanager_user_exists = require "core.usermanager".user_exists; local usermanager_create_user = require "core.usermanager".create_user; local usermanager_delete_user = require "core.usermanager".delete_user; @@ -19,14 +20,15 @@ local usermanager_get_password = require "core.usermanager".get_password; local usermanager_set_password = require "core.usermanager".set_password; local hostmanager_activate = require "core.hostmanager".activate; local hostmanager_deactivate = require "core.hostmanager".deactivate; -local is_admin = require "core.usermanager".is_admin; local rm_load_roster = require "core.rostermanager".load_roster; -local st, jid, uuid = require "util.stanza", require "util.jid", require "util.uuid"; +local st, jid = require "util.stanza", require "util.jid"; local timer_add_task = require "util.timer".add_task; local dataforms_new = require "util.dataforms".new; local array = require "util.array"; local modulemanager = require "modulemanager"; local core_post_stanza = prosody.core_post_stanza; +local adhoc_simple = require "util.adhoc".new_simple_form; +local adhoc_initial = require "util.adhoc".new_initial_data_form; module:depends("adhoc"); local adhoc_new = module:require "adhoc".new; @@ -39,82 +41,69 @@ local function generate_error_message(errors) return { status = "completed", error = { message = t_concat(errmsg, "\n") } }; end -function add_user_command_handler(self, data, state) - local add_user_layout = dataforms_new{ - title = "Adding a User"; - instructions = "Fill out this form to add a user."; +-- Adding a new user +local add_user_layout = dataforms_new{ + title = "Adding a User"; + instructions = "Fill out this form to add a user."; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for the account to be added" }; - { name = "password", type = "text-private", label = "The password for this account" }; - { name = "password-verify", type = "text-private", label = "Retype password" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for the account to be added" }; + { name = "password", type = "text-private", label = "The password for this account" }; + { name = "password-verify", type = "text-private", label = "Retype password" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = add_user_layout:data(data.form); - if err then - return generate_error_message(err); - end - local username, host, resource = jid.split(fields.accountjid); - if data.to ~= host then - return { status = "completed", error = { message = "Trying to add a user on " .. host .. " but command was sent to " .. data.to}}; - end - if (fields["password"] == fields["password-verify"]) and username and host then - if usermanager_user_exists(username, host) then - return { status = "completed", error = { message = "Account already exists" } }; +local add_user_command_handler = adhoc_simple(add_user_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local username, host, resource = jid.split(fields.accountjid); + if module_host ~= host then + return { status = "completed", error = { message = "Trying to add a user on " .. host .. " but command was sent to " .. module_host}}; + end + if (fields["password"] == fields["password-verify"]) and username and host then + if usermanager_user_exists(username, host) then + return { status = "completed", error = { message = "Account already exists" } }; + else + if usermanager_create_user(username, fields.password, host) then + module:log("info", "Created new account %s@%s", username, host); + return { status = "completed", info = "Account successfully created" }; else - if usermanager_create_user(username, fields.password, host) then - module:log("info", "Created new account %s@%s", username, host); - return { status = "completed", info = "Account successfully created" }; - else - return { status = "completed", error = { message = "Failed to write data to disk" } }; - end + return { status = "completed", error = { message = "Failed to write data to disk" } }; end - else - module:log("debug", "Invalid data, password mismatch or empty username while creating account for %s", fields.accountjid or "<nil>"); - return { status = "completed", error = { message = "Invalid data.\nPassword mismatch, or empty username" } }; end else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = add_user_layout }, "executing"; + module:log("debug", "Invalid data, password mismatch or empty username while creating account for %s", fields.accountjid or "<nil>"); + return { status = "completed", error = { message = "Invalid data.\nPassword mismatch, or empty username" } }; end -end +end); -function change_user_password_command_handler(self, data, state) - local change_user_password_layout = dataforms_new{ - title = "Changing a User Password"; - instructions = "Fill out this form to change a user's password."; +-- Changing a user's password +local change_user_password_layout = dataforms_new{ + title = "Changing a User Password"; + instructions = "Fill out this form to change a user's password."; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for this account" }; - { name = "password", type = "text-private", required = true, label = "The password for this account" }; - }; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for this account" }; + { name = "password", type = "text-private", required = true, label = "The password for this account" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = change_user_password_layout:data(data.form); - if err then - return generate_error_message(err); - end - local username, host, resource = jid.split(fields.accountjid); - if data.to ~= host then - return { status = "completed", error = { message = "Trying to change the password of a user on " .. host .. " but command was sent to " .. data.to}}; - end - if usermanager_user_exists(username, host) and usermanager_set_password(username, fields.password, host) then - return { status = "completed", info = "Password successfully changed" }; - else - return { status = "completed", error = { message = "User does not exist" } }; - end +local change_user_password_command_handler = adhoc_simple(change_user_password_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local username, host, resource = jid.split(fields.accountjid); + if module_host ~= host then + return { status = "completed", error = { message = "Trying to change the password of a user on " .. host .. " but command was sent to " .. module_host}}; + end + if usermanager_user_exists(username, host) and usermanager_set_password(username, fields.password, host) then + return { status = "completed", info = "Password successfully changed" }; else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = change_user_password_layout }, "executing"; + return { status = "completed", error = { message = "User does not exist" } }; end -end +end); -function config_reload_handler(self, data, state) +-- Reloading the config +local function config_reload_handler(self, data, state) local ok, err = prosody.reload_config(); if ok then return { status = "completed", info = "Configuration reloaded (modules may need to be reloaded for this to have an effect)" }; @@ -123,46 +112,39 @@ function config_reload_handler(self, data, state) end end +-- Deleting a user's account +local delete_user_layout = dataforms_new{ + title = "Deleting a User"; + instructions = "Fill out this form to delete a user."; -function delete_user_command_handler(self, data, state) - local delete_user_layout = dataforms_new{ - title = "Deleting a User"; - instructions = "Fill out this form to delete a user."; + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) to delete" }; +}; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) to delete" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = delete_user_layout:data(data.form); - if err then - return generate_error_message(err); - end - local failed = {}; - local succeeded = {}; - for _, aJID in ipairs(fields.accountjids) do - local username, host, resource = jid.split(aJID); - if (host == data.to) and usermanager_user_exists(username, host) and usermanager_delete_user(username, host) then - module:log("debug", "User %s has been deleted", aJID); - succeeded[#succeeded+1] = aJID; - else - module:log("debug", "Tried to delete non-existant user %s", aJID); - failed[#failed+1] = aJID; - end +local delete_user_command_handler = adhoc_simple(delete_user_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local failed = {}; + local succeeded = {}; + for _, aJID in ipairs(fields.accountjids) do + local username, host, resource = jid.split(aJID); + if (host == module_host) and usermanager_user_exists(username, host) and usermanager_delete_user(username, host) then + module:log("debug", "User %s has been deleted", aJID); + succeeded[#succeeded+1] = aJID; + else + module:log("debug", "Tried to delete non-existant user %s", aJID); + failed[#failed+1] = aJID; end - return {status = "completed", info = (#succeeded ~= 0 and - "The following accounts were successfully deleted:\n"..t_concat(succeeded, "\n").."\n" or "").. - (#failed ~= 0 and - "The following accounts could not be deleted:\n"..t_concat(failed, "\n") or "") }; - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = delete_user_layout }, "executing"; end -end - -function disconnect_user(match_jid) + return {status = "completed", info = (#succeeded ~= 0 and + "The following accounts were successfully deleted:\n"..t_concat(succeeded, "\n").."\n" or "").. + (#failed ~= 0 and + "The following accounts could not be deleted:\n"..t_concat(failed, "\n") or "") }; +end); + +-- Ending a user's session +local function disconnect_user(match_jid) local node, hostname, givenResource = jid.split(match_jid); local host = hosts[hostname]; local sessions = host.sessions[node] and host.sessions[node].sessions; @@ -175,447 +157,382 @@ function disconnect_user(match_jid) return true; end -function end_user_session_handler(self, data, state) - local end_user_session_layout = dataforms_new{ - title = "Ending a User Session"; - instructions = "Fill out this form to end a user's session."; +local end_user_session_layout = dataforms_new{ + title = "Ending a User Session"; + instructions = "Fill out this form to end a user's session."; - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) for which to end sessions" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjids", type = "jid-multi", label = "The Jabber ID(s) for which to end sessions" }; +}; - local fields, err = end_user_session_layout:data(data.form); - if err then - return generate_error_message(err); - end - local failed = {}; - local succeeded = {}; - for _, aJID in ipairs(fields.accountjids) do - local username, host, resource = jid.split(aJID); - if (host == data.to) and usermanager_user_exists(username, host) and disconnect_user(aJID) then - succeeded[#succeeded+1] = aJID; - else - failed[#failed+1] = aJID; - end - end - return {status = "completed", info = (#succeeded ~= 0 and - "The following accounts were successfully disconnected:\n"..t_concat(succeeded, "\n").."\n" or "").. - (#failed ~= 0 and - "The following accounts could not be disconnected:\n"..t_concat(failed, "\n") or "") }; - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = end_user_session_layout }, "executing"; +local end_user_session_handler = adhoc_simple(end_user_session_layout, function(fields, err) + if err then + return generate_error_message(err); end -end - -function get_user_password_handler(self, data, state) - local get_user_password_layout = dataforms_new{ - title = "Getting User's Password"; - instructions = "Fill out this form to get a user's password."; - - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the password" }; - }; - - local get_user_password_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", label = "JID" }; - { name = "password", type = "text-single", label = "Password" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = get_user_password_layout:data(data.form); - if err then - return generate_error_message(err); - end - local user, host, resource = jid.split(fields.accountjid); - local accountjid = ""; - local password = ""; - if host ~= data.to then - return { status = "completed", error = { message = "Tried to get password for a user on " .. host .. " but command was sent to " .. data.to } }; - elseif usermanager_user_exists(user, host) then - accountjid = fields.accountjid; - password = usermanager_get_password(user, host); + local failed = {}; + local succeeded = {}; + for _, aJID in ipairs(fields.accountjids) do + local username, host, resource = jid.split(aJID); + if (host == module_host) and usermanager_user_exists(username, host) and disconnect_user(aJID) then + succeeded[#succeeded+1] = aJID; else - return { status = "completed", error = { message = "User does not exist" } }; + failed[#failed+1] = aJID; end - return { status = "completed", result = { layout = get_user_password_result_layout, values = {accountjid = accountjid, password = password} } }; + end + return {status = "completed", info = (#succeeded ~= 0 and + "The following accounts were successfully disconnected:\n"..t_concat(succeeded, "\n").."\n" or "").. + (#failed ~= 0 and + "The following accounts could not be disconnected:\n"..t_concat(failed, "\n") or "") }; +end); + +-- Getting a user's password +local get_user_password_layout = dataforms_new{ + title = "Getting User's Password"; + instructions = "Fill out this form to get a user's password."; + + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the password" }; +}; + +local get_user_password_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", label = "JID" }; + { name = "password", type = "text-single", label = "Password" }; +}; + +local get_user_password_handler = adhoc_simple(get_user_password_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local user, host, resource = jid.split(fields.accountjid); + local accountjid = ""; + local password = ""; + if host ~= module_host then + return { status = "completed", error = { message = "Tried to get password for a user on " .. host .. " but command was sent to " .. module_host } }; + elseif usermanager_user_exists(user, host) then + accountjid = fields.accountjid; + password = usermanager_get_password(user, host); else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = get_user_password_layout }, "executing"; + return { status = "completed", error = { message = "User does not exist" } }; + end + return { status = "completed", result = { layout = get_user_password_result_layout, values = {accountjid = accountjid, password = password} } }; +end); + +-- Getting a user's roster +local get_user_roster_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the roster" }; +}; + +local get_user_roster_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", label = "This is the roster for" }; + { name = "roster", type = "text-multi", label = "Roster XML" }; +}; + +local get_user_roster_handler = adhoc_simple(get_user_roster_layout, function(fields, err) + if err then + return generate_error_message(err); end -end - -function get_user_roster_handler(self, data, state) - local get_user_roster_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for which to retrieve the roster" }; - }; - - local get_user_roster_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", label = "This is the roster for" }; - { name = "roster", type = "text-multi", label = "Roster XML" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - - local fields, err = get_user_roster_layout:data(data.form); - - if err then - return generate_error_message(err); - end - local user, host, resource = jid.split(fields.accountjid); - if host ~= data.to then - return { status = "completed", error = { message = "Tried to get roster for a user on " .. host .. " but command was sent to " .. data.to } }; - elseif not usermanager_user_exists(user, host) then - return { status = "completed", error = { message = "User does not exist" } }; - end - local roster = rm_load_roster(user, host); - - local query = st.stanza("query", { xmlns = "jabber:iq:roster" }); - for jid in pairs(roster) do - if jid ~= "pending" and jid then - query:tag("item", { - jid = jid, - subscription = roster[jid].subscription, - ask = roster[jid].ask, - name = roster[jid].name, - }); - for group in pairs(roster[jid].groups) do - query:tag("group"):text(group):up(); - end - query:up(); + local user, host, resource = jid.split(fields.accountjid); + if host ~= module_host then + return { status = "completed", error = { message = "Tried to get roster for a user on " .. host .. " but command was sent to " .. module_host } }; + elseif not usermanager_user_exists(user, host) then + return { status = "completed", error = { message = "User does not exist" } }; + end + local roster = rm_load_roster(user, host); + + local query = st.stanza("query", { xmlns = "jabber:iq:roster" }); + for jid in pairs(roster) do + if jid ~= "pending" and jid then + query:tag("item", { + jid = jid, + subscription = roster[jid].subscription, + ask = roster[jid].ask, + name = roster[jid].name, + }); + for group in pairs(roster[jid].groups) do + query:tag("group"):text(group):up(); end + query:up(); end - - local query_text = tostring(query):gsub("><", ">\n<"); - - local result = get_user_roster_result_layout:form({ accountjid = user.."@"..host, roster = query_text }, "result"); - result:add_child(query); - return { status = "completed", other = result }; - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = get_user_roster_layout }, "executing"; end -end -function get_user_stats_handler(self, data, state) - local get_user_stats_layout = dataforms_new{ - title = "Get User Statistics"; - instructions = "Fill out this form to gather user statistics."; + local query_text = tostring(query):gsub("><", ">\n<"); - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for statistics" }; - }; + local result = get_user_roster_result_layout:form({ accountjid = user.."@"..host, roster = query_text }, "result"); + result:add_child(query); + return { status = "completed", other = result }; +end); - local get_user_stats_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "ipaddresses", type = "text-multi", label = "IP Addresses" }; - { name = "rostersize", type = "text-single", label = "Roster size" }; - { name = "onlineresources", type = "text-multi", label = "Online Resources" }; - }; +-- Getting user statistics +local get_user_stats_layout = dataforms_new{ + title = "Get User Statistics"; + instructions = "Fill out this form to gather user statistics."; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - - local fields, err = get_user_stats_layout:data(data.form); + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "accountjid", type = "jid-single", required = true, label = "The Jabber ID for statistics" }; +}; - if err then - return generate_error_message(err); - end +local get_user_stats_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "ipaddresses", type = "text-multi", label = "IP Addresses" }; + { name = "rostersize", type = "text-single", label = "Roster size" }; + { name = "onlineresources", type = "text-multi", label = "Online Resources" }; +}; - local user, host, resource = jid.split(fields.accountjid); - if host ~= data.to then - return { status = "completed", error = { message = "Tried to get stats for a user on " .. host .. " but command was sent to " .. data.to } }; - elseif not usermanager_user_exists(user, host) then - return { status = "completed", error = { message = "User does not exist" } }; - end - local roster = rm_load_roster(user, host); - local rostersize = 0; - local IPs = ""; - local resources = ""; - for jid in pairs(roster) do - if jid ~= "pending" and jid then - rostersize = rostersize + 1; - end - end - for resource, session in pairs((hosts[host].sessions[user] and hosts[host].sessions[user].sessions) or {}) do - resources = resources .. "\n" .. resource; - IPs = IPs .. "\n" .. session.ip; - end - return { status = "completed", result = {layout = get_user_stats_result_layout, values = {ipaddresses = IPs, rostersize = tostring(rostersize), - onlineresources = resources}} }; - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = get_user_stats_layout }, "executing"; +local get_user_stats_handler = adhoc_simple(get_user_stats_layout, function(fields, err) + if err then + return generate_error_message(err); end -end -function get_online_users_command_handler(self, data, state) - local get_online_users_layout = dataforms_new{ - title = "Getting List of Online Users"; - instructions = "How many users should be returned at most?"; - - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "max_items", type = "list-single", label = "Maximum number of users", - value = { "25", "50", "75", "100", "150", "200", "all" } }; - { name = "details", type = "boolean", label = "Show details" }; - }; - - local get_online_users_result_layout = dataforms_new{ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "onlineuserjids", type = "text-multi", label = "The list of all online users" }; - }; - - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - - local fields, err = get_online_users_layout:data(data.form); - - if err then - return generate_error_message(err); + local user, host, resource = jid.split(fields.accountjid); + if host ~= module_host then + return { status = "completed", error = { message = "Tried to get stats for a user on " .. host .. " but command was sent to " .. module_host } }; + elseif not usermanager_user_exists(user, host) then + return { status = "completed", error = { message = "User does not exist" } }; + end + local roster = rm_load_roster(user, host); + local rostersize = 0; + local IPs = ""; + local resources = ""; + for jid in pairs(roster) do + if jid ~= "pending" and jid then + rostersize = rostersize + 1; end + end + for resource, session in pairs((hosts[host].sessions[user] and hosts[host].sessions[user].sessions) or {}) do + resources = resources .. "\n" .. resource; + IPs = IPs .. "\n" .. session.ip; + end + return { status = "completed", result = {layout = get_user_stats_result_layout, values = {ipaddresses = IPs, rostersize = tostring(rostersize), + onlineresources = resources}} }; +end); + +-- Getting a list of online users +local get_online_users_layout = dataforms_new{ + title = "Getting List of Online Users"; + instructions = "How many users should be returned at most?"; + + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "max_items", type = "list-single", label = "Maximum number of users", + value = { "25", "50", "75", "100", "150", "200", "all" } }; + { name = "details", type = "boolean", label = "Show details" }; +}; + +local get_online_users_result_layout = dataforms_new{ + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "onlineuserjids", type = "text-multi", label = "The list of all online users" }; +}; + +local get_online_users_command_handler = adhoc_simple(get_online_users_layout, function(fields, err) + if err then + return generate_error_message(err); + end - local max_items = nil - if fields.max_items ~= "all" then - max_items = tonumber(fields.max_items); - end - local count = 0; - local users = {}; - for username, user in pairs(hosts[data.to].sessions or {}) do - if (max_items ~= nil) and (count >= max_items) then - break; - end - users[#users+1] = username.."@"..data.to; - count = count + 1; - if fields.details then - for resource, session in pairs(user.sessions or {}) do - local status, priority = "unavailable", tostring(session.priority or "-"); - if session.presence then - status = session.presence:child_with_name("show"); - if status then - status = status:get_text() or "[invalid!]"; - else - status = "available"; - end + local max_items = nil + if fields.max_items ~= "all" then + max_items = tonumber(fields.max_items); + end + local count = 0; + local users = {}; + for username, user in pairs(hosts[module_host].sessions or {}) do + if (max_items ~= nil) and (count >= max_items) then + break; + end + users[#users+1] = username.."@"..module_host; + count = count + 1; + if fields.details then + for resource, session in pairs(user.sessions or {}) do + local status, priority = "unavailable", tostring(session.priority or "-"); + if session.presence then + status = session.presence:child_with_name("show"); + if status then + status = status:get_text() or "[invalid!]"; + else + status = "available"; end - users[#users+1] = " - "..resource..": "..status.."("..priority..")"; end + users[#users+1] = " - "..resource..": "..status.."("..priority..")"; end end - return { status = "completed", result = {layout = get_online_users_result_layout, values = {onlineuserjids=t_concat(users, "\n")}} }; - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = get_online_users_layout }, "executing"; end -end + return { status = "completed", result = {layout = get_online_users_result_layout, values = {onlineuserjids=t_concat(users, "\n")}} }; +end); -function list_modules_handler(self, data, state) - local result = dataforms_new { - title = "List of loaded modules"; +-- Getting a list of loaded modules +local list_modules_result = dataforms_new { + title = "List of loaded modules"; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#list" }; - { name = "modules", type = "text-multi", label = "The following modules are loaded:" }; - }; - - local modules = array.collect(keys(hosts[data.to].modules)):sort():concat("\n"); + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#list" }; + { name = "modules", type = "text-multi", label = "The following modules are loaded:" }; +}; - return { status = "completed", result = { layout = result; values = { modules = modules } } }; +local function list_modules_handler(self, data, state) + local modules = array.collect(keys(hosts[module_host].modules)):sort():concat("\n"); + return { status = "completed", result = { layout = list_modules_result; values = { modules = modules } } }; end -function load_module_handler(self, data, state) - local layout = dataforms_new { - title = "Load module"; - instructions = "Specify the module to be loaded"; +-- Loading a module +local load_module_layout = dataforms_new { + title = "Load module"; + instructions = "Specify the module to be loaded"; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#load" }; - { name = "module", type = "text-single", required = true, label = "Module to be loaded:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); - end - if modulemanager.is_loaded(data.to, fields.module) then - return { status = "completed", info = "Module already loaded" }; - end - local ok, err = modulemanager.load(data.to, fields.module); - if ok then - return { status = "completed", info = 'Module "'..fields.module..'" successfully loaded on host "'..data.to..'".' }; - else - return { status = "completed", error = { message = 'Failed to load module "'..fields.module..'" on host "'..data.to.. - '". Error was: "'..tostring(err or "<unspecified>")..'"' } }; - end + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#load" }; + { name = "module", type = "text-single", required = true, label = "Module to be loaded:"}; +}; + +local load_module_handler = adhoc_simple(load_module_layout, function(fields, err) + if err then + return generate_error_message(err); + end + if modulemanager.is_loaded(module_host, fields.module) then + return { status = "completed", info = "Module already loaded" }; + end + local ok, err = modulemanager.load(module_host, fields.module); + if ok then + return { status = "completed", info = 'Module "'..fields.module..'" successfully loaded on host "'..module_host..'".' }; else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = layout }, "executing"; + return { status = "completed", error = { message = 'Failed to load module "'..fields.module..'" on host "'..module_host.. + '". Error was: "'..tostring(err or "<unspecified>")..'"' } }; end -end - -local function globally_load_module_handler(self, data, state) - local layout = dataforms_new { - title = "Globally load module"; - instructions = "Specify the module to be loaded on all hosts"; - - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-load" }; - { name = "module", type = "text-single", required = true, label = "Module to globally load:"}; - }; - if state then - local ok_list, err_list = {}, {}; +end); - if data.action == "cancel" then - return { status = "canceled" }; - end +-- Globally loading a module +local globally_load_module_layout = dataforms_new { + title = "Globally load module"; + instructions = "Specify the module to be loaded on all hosts"; - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); - end + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-load" }; + { name = "module", type = "text-single", required = true, label = "Module to globally load:"}; +}; - local ok, err = modulemanager.load(data.to, fields.module); - if ok then - ok_list[#ok_list + 1] = data.to; - else - err_list[#err_list + 1] = data.to .. " (Error: " .. tostring(err) .. ")"; - end +local globally_load_module_handler = adhoc_simple(globally_load_module_layout, function(fields, err) + local ok_list, err_list = {}, {}; - -- Is this a global module? - if modulemanager.is_loaded("*", fields.module) and not modulemanager.is_loaded(data.to, fields.module) then - return { status = "completed", info = 'Global module '..fields.module..' loaded.' }; - end - - -- This is either a shared or "normal" module, load it on all other hosts - for host_name, host in pairs(hosts) do - if host_name ~= data.to and host.type == "local" then - local ok, err = modulemanager.load(host_name, fields.module); - if ok then - ok_list[#ok_list + 1] = host_name; - else - err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; - end - end - end + if err then + return generate_error_message(err); + end - local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully loaded onto the hosts:\n"..t_concat(ok_list, "\n")) or "") - .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. - (#err_list > 0 and ("Failed to load the module "..fields.module.." onto the hosts:\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; + local ok, err = modulemanager.load(module_host, fields.module); + if ok then + ok_list[#ok_list + 1] = module_host; else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = layout }, "executing"; + err_list[#err_list + 1] = module_host .. " (Error: " .. tostring(err) .. ")"; end -end -function reload_modules_handler(self, data, state) - local layout = dataforms_new { - title = "Reload modules"; - instructions = "Select the modules to be reloaded"; + -- Is this a global module? + if modulemanager.is_loaded("*", fields.module) and not modulemanager.is_loaded(module_host, fields.module) then + return { status = "completed", info = 'Global module '..fields.module..' loaded.' }; + end - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#reload" }; - { name = "modules", type = "list-multi", required = true, label = "Modules to be reloaded:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); - end - local ok_list, err_list = {}, {}; - for _, module in ipairs(fields.modules) do - local ok, err = modulemanager.reload(data.to, module); + -- This is either a shared or "normal" module, load it on all other hosts + for host_name, host in pairs(hosts) do + if host_name ~= module_host and host.type == "local" then + local ok, err = modulemanager.load(host_name, fields.module); if ok then - ok_list[#ok_list + 1] = module; + ok_list[#ok_list + 1] = host_name; else - err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; + err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; end end - local info = (#ok_list > 0 and ("The following modules were successfully reloaded on host "..data.to..":\n"..t_concat(ok_list, "\n")) or "") - .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. - (#err_list > 0 and ("Failed to reload the following modules on host "..data.to..":\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; - else - local modules = array.collect(keys(hosts[data.to].modules)):sort(); - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = { layout = layout; values = { modules = modules } } }, "executing"; end -end -local function globally_reload_module_handler(self, data, state) - local layout = dataforms_new { - title = "Globally reload module"; - instructions = "Specify the module to reload on all hosts"; - - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-reload" }; - { name = "module", type = "list-single", required = true, label = "Module to globally reload:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - - local is_global = false; - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); + local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully loaded onto the hosts:\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to load the module "..fields.module.." onto the hosts:\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Reloading modules +local reload_modules_layout = dataforms_new { + title = "Reload modules"; + instructions = "Select the modules to be reloaded"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#reload" }; + { name = "modules", type = "list-multi", required = true, label = "Modules to be reloaded:"}; +}; + +local reload_modules_handler = adhoc_initial(reload_modules_layout, function() + return { modules = array.collect(keys(hosts[module_host].modules)):sort() }; +end, function(fields, err) + if err then + return generate_error_message(err); + end + local ok_list, err_list = {}, {}; + for _, module in ipairs(fields.modules) do + local ok, err = modulemanager.reload(module_host, module); + if ok then + ok_list[#ok_list + 1] = module; + else + err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; end + end + local info = (#ok_list > 0 and ("The following modules were successfully reloaded on host "..module_host..":\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to reload the following modules on host "..module_host..":\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Globally reloading a module +local globally_reload_module_layout = dataforms_new { + title = "Globally reload module"; + instructions = "Specify the module to reload on all hosts"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-reload" }; + { name = "module", type = "list-single", required = true, label = "Module to globally reload:"}; +}; + +local globally_reload_module_handler = adhoc_initial(globally_reload_module_layout, function() + local loaded_modules = array(keys(modulemanager.get_modules("*"))); + for _, host in pairs(hosts) do + loaded_modules:append(array(keys(host.modules))); + end + loaded_modules = array(keys(set.new(loaded_modules):items())):sort(); + return { module = loaded_modules }; +end, function(fields, err) + local is_global = false; - if modulemanager.is_loaded("*", fields.module) then - local ok, err = modulemanager.reload("*", fields.module); - if not ok then - return { status = "completed", info = 'Global module '..fields.module..' failed to reload: '..err }; - end - is_global = true; - end + if err then + return generate_error_message(err); + end - local ok_list, err_list = {}, {}; - for host_name, host in pairs(hosts) do - if modulemanager.is_loaded(host_name, fields.module) then - local ok, err = modulemanager.reload(host_name, fields.module); - if ok then - ok_list[#ok_list + 1] = host_name; - else - err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; - end - end + if modulemanager.is_loaded("*", fields.module) then + local ok, err = modulemanager.reload("*", fields.module); + if not ok then + return { status = "completed", info = 'Global module '..fields.module..' failed to reload: '..err }; end + is_global = true; + end - if #ok_list == 0 and #err_list == 0 then - if is_global then - return { status = "completed", info = 'Successfully reloaded global module '..fields.module }; + local ok_list, err_list = {}, {}; + for host_name, host in pairs(hosts) do + if modulemanager.is_loaded(host_name, fields.module) then + local ok, err = modulemanager.reload(host_name, fields.module); + if ok then + ok_list[#ok_list + 1] = host_name; else - return { status = "completed", info = 'Module '..fields.module..' not loaded on any host.' }; + err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; end end + end - local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully reloaded on the hosts:\n"..t_concat(ok_list, "\n")) or "") - .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. - (#err_list > 0 and ("Failed to reload the module "..fields.module.." on the hosts:\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; - else - local loaded_modules = array(keys(modulemanager.get_modules("*"))); - for _, host in pairs(hosts) do - loaded_modules:append(array(keys(host.modules))); + if #ok_list == 0 and #err_list == 0 then + if is_global then + return { status = "completed", info = 'Successfully reloaded global module '..fields.module }; + else + return { status = "completed", info = 'Module '..fields.module..' not loaded on any host.' }; end - loaded_modules = array(keys(set.new(loaded_modules):items())):sort(); - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = { layout = layout, values = { module = loaded_modules } } }, "executing"; end -end -function send_to_online(message, server) + local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully reloaded on the hosts:\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to reload the module "..fields.module.." on the hosts:\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +local function send_to_online(message, server) if server then sessions = { [server] = hosts[server] }; else @@ -635,202 +552,170 @@ function send_to_online(message, server) return c; end -function shut_down_service_handler(self, data, state) - local shut_down_service_layout = dataforms_new{ - title = "Shutting Down the Service"; - instructions = "Fill out this form to shut down the service."; - - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; - { name = "delay", type = "list-single", label = "Time delay before shutting down", - value = { {label = "30 seconds", value = "30"}, - {label = "60 seconds", value = "60"}, - {label = "90 seconds", value = "90"}, - {label = "2 minutes", value = "120"}, - {label = "3 minutes", value = "180"}, - {label = "4 minutes", value = "240"}, - {label = "5 minutes", value = "300"}, - }; +-- Shutting down the service +local shut_down_service_layout = dataforms_new{ + title = "Shutting Down the Service"; + instructions = "Fill out this form to shut down the service."; + + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/admin" }; + { name = "delay", type = "list-single", label = "Time delay before shutting down", + value = { {label = "30 seconds", value = "30"}, + {label = "60 seconds", value = "60"}, + {label = "90 seconds", value = "90"}, + {label = "2 minutes", value = "120"}, + {label = "3 minutes", value = "180"}, + {label = "4 minutes", value = "240"}, + {label = "5 minutes", value = "300"}, }; - { name = "announcement", type = "text-multi", label = "Announcement" }; }; + { name = "announcement", type = "text-multi", label = "Announcement" }; +}; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end +local shut_down_service_handler = adhoc_simple(shut_down_service_layout, function(fields, err) + if err then + return generate_error_message(err); + end - local fields, err = shut_down_service_layout:data(data.form); + if fields.announcement and #fields.announcement > 0 then + local message = st.message({type = "headline"}, fields.announcement):up() + :tag("subject"):text("Server is shutting down"); + send_to_online(message); + end - if err then - return generate_error_message(err); - end + timer_add_task(tonumber(fields.delay or "5"), function(time) prosody.shutdown("Shutdown by adhoc command") end); - if fields.announcement and #fields.announcement > 0 then - local message = st.message({type = "headline"}, fields.announcement):up() - :tag("subject"):text("Server is shutting down"); - send_to_online(message); - end + return { status = "completed", info = "Server is about to shut down" }; +end); - timer_add_task(tonumber(fields.delay or "5"), function(time) prosody.shutdown("Shutdown by adhoc command") end); +-- Unloading modules +local unload_modules_layout = dataforms_new { + title = "Unload modules"; + instructions = "Select the modules to be unloaded"; - return { status = "completed", info = "Server is about to shut down" }; - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = shut_down_service_layout }, "executing"; - end -end - -function unload_modules_handler(self, data, state) - local layout = dataforms_new { - title = "Unload modules"; - instructions = "Select the modules to be unloaded"; + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#unload" }; + { name = "modules", type = "list-multi", required = true, label = "Modules to be unloaded:"}; +}; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#unload" }; - { name = "modules", type = "list-multi", required = true, label = "Modules to be unloaded:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; +local unload_modules_handler = adhoc_initial(unload_modules_layout, function() + return { modules = array.collect(keys(hosts[module_host].modules)):sort() }; +end, function(fields, err) + if err then + return generate_error_message(err); + end + local ok_list, err_list = {}, {}; + for _, module in ipairs(fields.modules) do + local ok, err = modulemanager.unload(module_host, module); + if ok then + ok_list[#ok_list + 1] = module; + else + err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; end - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); + end + local info = (#ok_list > 0 and ("The following modules were successfully unloaded on host "..module_host..":\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to unload the following modules on host "..module_host..":\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); + +-- Globally unloading a module +local globally_unload_module_layout = dataforms_new { + title = "Globally unload module"; + instructions = "Specify a module to unload on all hosts"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-unload" }; + { name = "module", type = "list-single", required = true, label = "Module to globally unload:"}; +}; + +local globally_unload_module_handler = adhoc_initial(globally_unload_module_layout, function() + local loaded_modules = array(keys(modulemanager.get_modules("*"))); + for _, host in pairs(hosts) do + loaded_modules:append(array(keys(host.modules))); + end + loaded_modules = array(keys(set.new(loaded_modules):items())):sort(); + return { module = loaded_modules }; +end, function(fields, err) + local is_global = false; + if err then + return generate_error_message(err); + end + + if modulemanager.is_loaded("*", fields.module) then + local ok, err = modulemanager.unload("*", fields.module); + if not ok then + return { status = "completed", info = 'Global module '..fields.module..' failed to unload: '..err }; end - local ok_list, err_list = {}, {}; - for _, module in ipairs(fields.modules) do - local ok, err = modulemanager.unload(data.to, module); + is_global = true; + end + + local ok_list, err_list = {}, {}; + for host_name, host in pairs(hosts) do + if modulemanager.is_loaded(host_name, fields.module) then + local ok, err = modulemanager.unload(host_name, fields.module); if ok then - ok_list[#ok_list + 1] = module; + ok_list[#ok_list + 1] = host_name; else - err_list[#err_list + 1] = module .. "(Error: " .. tostring(err) .. ")"; + err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; end end - local info = (#ok_list > 0 and ("The following modules were successfully unloaded on host "..data.to..":\n"..t_concat(ok_list, "\n")) or "") - .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. - (#err_list > 0 and ("Failed to unload the following modules on host "..data.to..":\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; - else - local modules = array.collect(keys(hosts[data.to].modules)):sort(); - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = { layout = layout; values = { modules = modules } } }, "executing"; end -end - -local function globally_unload_module_handler(self, data, state) - local layout = dataforms_new { - title = "Globally unload module"; - instructions = "Specify a module to unload on all hosts"; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/modules#global-unload" }; - { name = "module", type = "list-single", required = true, label = "Module to globally unload:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; + if #ok_list == 0 and #err_list == 0 then + if is_global then + return { status = "completed", info = 'Successfully unloaded global module '..fields.module }; + else + return { status = "completed", info = 'Module '..fields.module..' not loaded on any host.' }; end + end - local is_global = false; - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); - end + local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully unloaded on the hosts:\n"..t_concat(ok_list, "\n")) or "") + .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. + (#err_list > 0 and ("Failed to unload the module "..fields.module.." on the hosts:\n"..t_concat(err_list, "\n")) or ""); + return { status = "completed", info = info }; +end); - if modulemanager.is_loaded("*", fields.module) then - local ok, err = modulemanager.unload("*", fields.module); - if not ok then - return { status = "completed", info = 'Global module '..fields.module..' failed to unload: '..err }; - end - is_global = true; - end +-- Activating a host +local activate_host_layout = dataforms_new { + title = "Activate host"; + instructions = ""; - local ok_list, err_list = {}, {}; - for host_name, host in pairs(hosts) do - if modulemanager.is_loaded(host_name, fields.module) then - local ok, err = modulemanager.unload(host_name, fields.module); - if ok then - ok_list[#ok_list + 1] = host_name; - else - err_list[#err_list + 1] = host_name .. " (Error: " .. tostring(err) .. ")"; - end - end - end + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/hosts#activate" }; + { name = "host", type = "text-single", required = true, label = "Host:"}; +}; - if #ok_list == 0 and #err_list == 0 then - if is_global then - return { status = "completed", info = 'Successfully unloaded global module '..fields.module }; - else - return { status = "completed", info = 'Module '..fields.module..' not loaded on any host.' }; - end - end +local activate_host_handler = adhoc_simple(activate_host_layout, function(fields, err) + if err then + return generate_error_message(err); + end + local ok, err = hostmanager_activate(fields.host); - local info = (#ok_list > 0 and ("The module "..fields.module.." was successfully unloaded on the hosts:\n"..t_concat(ok_list, "\n")) or "") - .. ((#ok_list > 0 and #err_list > 0) and "\n" or "") .. - (#err_list > 0 and ("Failed to unload the module "..fields.module.." on the hosts:\n"..t_concat(err_list, "\n")) or ""); - return { status = "completed", info = info }; + if ok then + return { status = "completed", info = fields.host .. " activated" }; else - local loaded_modules = array(keys(modulemanager.get_modules("*"))); - for _, host in pairs(hosts) do - loaded_modules:append(array(keys(host.modules))); - end - loaded_modules = array(keys(set.new(loaded_modules):items())):sort(); - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = { layout = layout, values = { module = loaded_modules } } }, "executing"; + return { status = "canceled", error = err } end -end +end); +-- Deactivating a host +local deactivate_host_layout = dataforms_new { + title = "Deactivate host"; + instructions = ""; -function activate_host_handler(self, data, state) - local layout = dataforms_new { - title = "Activate host"; - instructions = ""; + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/hosts#activate" }; + { name = "host", type = "text-single", required = true, label = "Host:"}; +}; - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/hosts#activate" }; - { name = "host", type = "text-single", required = true, label = "Host:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); - end - local ok, err = hostmanager_activate(fields.host); - - if ok then - return { status = "completed", info = fields.host .. " activated" }; - else - return { status = "canceled", error = err } - end - else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = { layout = layout } }, "executing"; +local deactivate_host_handler = adhoc_simple(deactivate_host_layout, function(fields, err) + if err then + return generate_error_message(err); end -end - -function deactivate_host_handler(self, data, state) - local layout = dataforms_new { - title = "Deactivate host"; - instructions = ""; + local ok, err = hostmanager_deactivate(fields.host); - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/hosts#activate" }; - { name = "host", type = "text-single", required = true, label = "Host:"}; - }; - if state then - if data.action == "cancel" then - return { status = "canceled" }; - end - local fields, err = layout:data(data.form); - if err then - return generate_error_message(err); - end - local ok, err = hostmanager_deactivate(fields.host); - - if ok then - return { status = "completed", info = fields.host .. " deactivated" }; - else - return { status = "canceled", error = err } - end + if ok then + return { status = "completed", info = fields.host .. " deactivated" }; else - return { status = "executing", actions = {"next", "complete", default = "complete"}, form = { layout = layout } }, "executing"; + return { status = "canceled", error = err } end -end +end); local add_user_desc = adhoc_new("Add User", "http://jabber.org/protocol/admin#add-user", add_user_command_handler, "admin"); diff --git a/plugins/mod_admin_telnet.lua b/plugins/mod_admin_telnet.lua index e1b90684..2622a5f9 100644 --- a/plugins/mod_admin_telnet.lua +++ b/plugins/mod_admin_telnet.lua @@ -903,13 +903,23 @@ local console_room_mt = { end; }; -function def_env.muc:room(room_jid) - local room_name, host = jid_split(room_jid); +local function check_muc(jid) + local room_name, host = jid_split(jid); if not hosts[host] then return nil, "No such host: "..host; elseif not hosts[host].modules.muc then return nil, "Host '"..host.."' is not a MUC service"; end + return room_name, host; +end + +function def_env.muc:create(room_jid) + local room, host = check_muc(room_jid); + return hosts[host].modules.muc.create_room(room_jid); +end + +function def_env.muc:room(room_jid) + local room_name, host = check_muc(room_jid); local room_obj = hosts[host].modules.muc.rooms[room_jid]; if not room_obj then return nil, "No such room: "..room_jid; diff --git a/plugins/mod_announce.lua b/plugins/mod_announce.lua index 0872bd21..96976d6f 100644 --- a/plugins/mod_announce.lua +++ b/plugins/mod_announce.lua @@ -8,6 +8,7 @@ local st, jid = require "util.stanza", require "util.jid"; +local hosts = prosody.hosts; local is_admin = require "core.usermanager".is_admin; function send_to_online(message, host) diff --git a/plugins/mod_auth_anonymous.lua b/plugins/mod_auth_anonymous.lua index a327f438..c877d532 100644 --- a/plugins/mod_auth_anonymous.lua +++ b/plugins/mod_auth_anonymous.lua @@ -8,6 +8,7 @@ local new_sasl = require "util.sasl".new; local datamanager = require "util.datamanager"; +local hosts = prosody.hosts; -- define auth provider local provider = {}; diff --git a/plugins/mod_auth_internal_hashed.lua b/plugins/mod_auth_internal_hashed.lua index cb6cc8ff..2b041e43 100644 --- a/plugins/mod_auth_internal_hashed.lua +++ b/plugins/mod_auth_internal_hashed.lua @@ -7,13 +7,14 @@ -- COPYING file in the source package for more information. -- -local datamanager = require "util.datamanager"; local log = require "util.logger".init("auth_internal_hashed"); local getAuthenticationDatabaseSHA1 = require "util.sasl.scram".getAuthenticationDatabaseSHA1; local usermanager = require "core.usermanager"; local generate_uuid = require "util.uuid".generate; local new_sasl = require "util.sasl".new; +local accounts = module:open_store("accounts"); + local to_hex; do local function replace_byte_with_hex(byte) @@ -44,7 +45,7 @@ local provider = {}; log("debug", "initializing internal_hashed authentication provider for host '%s'", host); function provider.test_password(username, password) - local credentials = datamanager.load(username, host, "accounts") or {}; + local credentials = accounts:get(username) or {}; if credentials.password ~= nil and string.len(credentials.password) ~= 0 then if credentials.password ~= password then @@ -75,7 +76,7 @@ function provider.test_password(username, password) end function provider.set_password(username, password) - local account = datamanager.load(username, host, "accounts"); + local account = accounts:get(username); if account then account.salt = account.salt or generate_uuid(); account.iteration_count = account.iteration_count or iteration_count; @@ -87,13 +88,13 @@ function provider.set_password(username, password) account.server_key = server_key_hex account.password = nil; - return datamanager.store(username, host, "accounts", account); + return accounts:set(username, account); end return nil, "Account not available."; end function provider.user_exists(username) - local account = datamanager.load(username, host, "accounts"); + local account = accounts:get(username); if not account then log("debug", "account not found for username '%s' at host '%s'", username, host); return nil, "Auth failed. Invalid username"; @@ -102,22 +103,22 @@ function provider.user_exists(username) end function provider.users() - return datamanager.users(host, "accounts"); + return accounts:users(); end function provider.create_user(username, password) if password == nil then - return datamanager.store(username, host, "accounts", {}); + return accounts:set(username, {}); end local salt = generate_uuid(); local valid, stored_key, server_key = getAuthenticationDatabaseSHA1(password, salt, iteration_count); local stored_key_hex = to_hex(stored_key); local server_key_hex = to_hex(server_key); - return datamanager.store(username, host, "accounts", {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = iteration_count}); + return accounts:set(username, {stored_key = stored_key_hex, server_key = server_key_hex, salt = salt, iteration_count = iteration_count}); end function provider.delete_user(username) - return datamanager.store(username, host, "accounts", nil); + return accounts:set(username, nil); end function provider.get_sasl_handler() @@ -126,11 +127,11 @@ function provider.get_sasl_handler() return usermanager.test_password(username, realm, password), true; end, scram_sha_1 = function(sasl, username, realm) - local credentials = datamanager.load(username, host, "accounts"); + local credentials = accounts:get(username); if not credentials then return; end if credentials.password then usermanager.set_password(username, credentials.password, host); - credentials = datamanager.load(username, host, "accounts"); + credentials = accounts:get(username); if not credentials then return; end end diff --git a/plugins/mod_auth_internal_plain.lua b/plugins/mod_auth_internal_plain.lua index 178ae5a5..d226fdbe 100644 --- a/plugins/mod_auth_internal_plain.lua +++ b/plugins/mod_auth_internal_plain.lua @@ -6,20 +6,21 @@ -- COPYING file in the source package for more information. -- -local datamanager = require "util.datamanager"; local usermanager = require "core.usermanager"; local new_sasl = require "util.sasl".new; local log = module._log; local host = module.host; +local accounts = module:open_store("accounts"); + -- define auth provider local provider = {}; log("debug", "initializing internal_plain authentication provider for host '%s'", host); function provider.test_password(username, password) - log("debug", "test password '%s' for user %s at host %s", password, username, host); - local credentials = datamanager.load(username, host, "accounts") or {}; + log("debug", "test password for user %s at host %s", username, host); + local credentials = accounts:get(username) or {}; if password == credentials.password then return true; @@ -30,20 +31,20 @@ end function provider.get_password(username) log("debug", "get_password for username '%s' at host '%s'", username, host); - return (datamanager.load(username, host, "accounts") or {}).password; + return (accounts:get(username) or {}).password; end function provider.set_password(username, password) - local account = datamanager.load(username, host, "accounts"); + local account = accounts:get(username); if account then account.password = password; - return datamanager.store(username, host, "accounts", account); + return accounts:set(username, account); end return nil, "Account not available."; end function provider.user_exists(username) - local account = datamanager.load(username, host, "accounts"); + local account = accounts:get(username); if not account then log("debug", "account not found for username '%s' at host '%s'", username, host); return nil, "Auth failed. Invalid username"; @@ -52,15 +53,15 @@ function provider.user_exists(username) end function provider.users() - return datamanager.users(host, "accounts"); + return accounts:users(); end function provider.create_user(username, password) - return datamanager.store(username, host, "accounts", {password = password}); + return accounts:set(username, {password = password}); end function provider.delete_user(username) - return datamanager.store(username, host, "accounts", nil); + return accounts:set(username, nil); end function provider.get_sasl_handler() diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua index 89d678ca..efef8763 100644 --- a/plugins/mod_c2s.lua +++ b/plugins/mod_c2s.lua @@ -29,6 +29,7 @@ local opt_keepalives = module:get_option_boolean("tcp_keepalives", false); local sessions = module:shared("sessions"); local core_process_stanza = prosody.core_process_stanza; +local hosts = prosody.hosts; local stream_callbacks = { default_ns = "jabber:client", handlestanza = core_process_stanza }; local listener = {}; @@ -115,7 +116,7 @@ function stream_callbacks.error(session, error, data) end end -local function handleerr(err) log("error", "Traceback[c2s]: %s: %s", tostring(err), traceback()); end +local function handleerr(err) log("error", "Traceback[c2s]: %s", traceback(tostring(err), 2)); end function stream_callbacks.handlestanza(session, stanza) stanza = session.filter("stanzas/in", stanza); if stanza then @@ -132,25 +133,25 @@ local function session_close(session, reason) session.send(st.stanza("stream:stream", default_stream_attr):top_tag()); end if reason then -- nil == no err, initiated by us, false == initiated by client + local stream_error = st.stanza("stream:error"); if type(reason) == "string" then -- assume stream error - log("debug", "Disconnecting client, <stream:error> is: %s", reason); - session.send(st.stanza("stream:error"):tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' })); + stream_error:tag(reason, {xmlns = 'urn:ietf:params:xml:ns:xmpp-streams' }); elseif type(reason) == "table" then if reason.condition then - local stanza = st.stanza("stream:error"):tag(reason.condition, stream_xmlns_attr):up(); + stream_error:tag(reason.condition, stream_xmlns_attr):up(); if reason.text then - stanza:tag("text", stream_xmlns_attr):text(reason.text):up(); + stream_error:tag("text", stream_xmlns_attr):text(reason.text):up(); end if reason.extra then - stanza:add_child(reason.extra); + stream_error:add_child(reason.extra); end - log("debug", "Disconnecting client, <stream:error> is: %s", tostring(stanza)); - session.send(stanza); elseif reason.name then -- a stanza - log("debug", "Disconnecting client, <stream:error> is: %s", tostring(reason)); - session.send(reason); + stream_error = reason; end end + stream_error = tostring(stream_error); + log("debug", "Disconnecting client, <stream:error> is: %s", stream_error); + session.send(stream_error); end session.send("</stream:stream>"); diff --git a/plugins/mod_component.lua b/plugins/mod_component.lua index 68d8a5de..871a20e4 100644 --- a/plugins/mod_component.lua +++ b/plugins/mod_component.lua @@ -19,7 +19,7 @@ local new_xmpp_stream = require "util.xmppstream".new; local uuid_gen = require "util.uuid".generate; local core_process_stanza = prosody.core_process_stanza; - +local hosts = prosody.hosts; local log = module._log; diff --git a/plugins/mod_compression.lua b/plugins/mod_compression.lua index 67a88eb9..92856099 100644 --- a/plugins/mod_compression.lua +++ b/plugins/mod_compression.lua @@ -141,10 +141,7 @@ module:hook("stanza/http://jabber.org/protocol/compress:compressed", function(ev -- setup decompression for session.data setup_decompression(session, inflate_stream); session:reset_stream(); - local default_stream_attr = {xmlns = "jabber:server", ["xmlns:stream"] = "http://etherx.jabber.org/streams", - ["xmlns:db"] = 'jabber:server:dialback', version = "1.0", to = session.to_host, from = session.from_host}; - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); + session:open_stream(session.from_host, session.to_host); session.compressed = true; return true; end diff --git a/plugins/mod_dialback.lua b/plugins/mod_dialback.lua index b2f84603..9dcb0ed5 100644 --- a/plugins/mod_dialback.lua +++ b/plugins/mod_dialback.lua @@ -7,7 +7,6 @@ -- local hosts = _G.hosts; -local s2s_make_authenticated = require "core.s2smanager".make_authenticated; local log = module._log; @@ -110,7 +109,7 @@ module:hook("stanza/jabber:server:dialback:verify", function(event) if dialback_verifying and attr.from == origin.to_host then local valid; if attr.type == "valid" then - s2s_make_authenticated(dialback_verifying, attr.from); + module:fire_event("s2s-authenticated", { session = dialback_verifying, host = attr.from }); valid = "valid"; else -- Warn the original connection that is was not verified successfully @@ -146,7 +145,7 @@ module:hook("stanza/jabber:server:dialback:result", function(event) return true; end if stanza.attr.type == "valid" then - s2s_make_authenticated(origin, attr.from); + module:fire_event("s2s-authenticated", { session = origin, host = attr.from }); else origin:close("not-authorized", "dialback authentication failed"); end @@ -170,7 +169,7 @@ module:hook_stanza(xmlns_stream, "features", function (origin, stanza) end end, 100); -module:hook("s2s-authenticate-legacy", function (event) +module:hook("s2sout-authenticate-legacy", function (event) module:log("debug", "Initiating dialback..."); initiate_dialback(event.origin); return true; diff --git a/plugins/mod_groups.lua b/plugins/mod_groups.lua index 7a876f1d..f7f632c2 100644 --- a/plugins/mod_groups.lua +++ b/plugins/mod_groups.lua @@ -13,7 +13,7 @@ local members; local groups_file; local jid, datamanager = require "util.jid", require "util.datamanager"; -local jid_bare, jid_prep = jid.bare, jid.prep; +local jid_prep = jid.prep; local module_host = module:get_host(); @@ -80,7 +80,7 @@ function remove_virtual_contacts(username, host, datastore, data) end function module.load() - groups_file = config.get(module:get_host(), "core", "groups_file"); + groups_file = module:get_option_string("groups_file"); if not groups_file then return; end module:hook("roster-load", inject_roster_contacts); @@ -121,3 +121,8 @@ end function module.unload() datamanager.remove_callback(remove_virtual_contacts); end + +-- Public for other modules to access +function group_contains(group_name, jid) + return groups[group_name][jid]; +end diff --git a/plugins/mod_http.lua b/plugins/mod_http.lua index 018f2ea3..0689634e 100644 --- a/plugins/mod_http.lua +++ b/plugins/mod_http.lua @@ -9,6 +9,7 @@ module:set_global(); module:depends("http_errors"); +local portmanager = require "core.portmanager"; local moduleapi = require "core.moduleapi"; local url_parse = require "socket.url".parse; local url_build = require "socket.url".build; @@ -38,9 +39,10 @@ local function get_http_event(host, app_path, key) end local function get_base_path(host_module, app_name, default_app_path) - return normalize_path(host_module:get_option("http_paths", {})[app_name] -- Host + return (normalize_path(host_module:get_option("http_paths", {})[app_name] -- Host or module:get_option("http_paths", {})[app_name] -- Global - or default_app_path); -- Default + or default_app_path)) -- Default + :gsub("%$(%w+)", { host = module.host }); end local ports_by_scheme = { http = 80, https = 443, }; @@ -137,6 +139,7 @@ module:provides("net", { listener = server.listener; default_port = 5281; encryption = "ssl"; + ssl_config = { verify = "none" }; multiplex = { pattern = "^[A-Z]"; }; diff --git a/plugins/mod_http_errors.lua b/plugins/mod_http_errors.lua index 828216dd..2568ea80 100644 --- a/plugins/mod_http_errors.lua +++ b/plugins/mod_http_errors.lua @@ -2,7 +2,6 @@ module:set_global(); local server = require "net.http.server"; local codes = require "net.http.codes"; -local termcolours = require "util.termcolours"; local show_private = module:get_option_boolean("http_errors_detailed", false); local always_serve = module:get_option_boolean("http_errors_always_show", true); diff --git a/plugins/mod_iq.lua b/plugins/mod_iq.lua index 8044a533..e7901ab4 100644 --- a/plugins/mod_iq.lua +++ b/plugins/mod_iq.lua @@ -9,7 +9,7 @@ local st = require "util.stanza"; -local full_sessions = full_sessions; +local full_sessions = prosody.full_sessions; if module:get_host_type() == "local" then module:hook("iq/full", function(data) diff --git a/plugins/mod_message.lua b/plugins/mod_message.lua index 0b0ad8e4..e85da613 100644 --- a/plugins/mod_message.lua +++ b/plugins/mod_message.lua @@ -7,8 +7,8 @@ -- -local full_sessions = full_sessions; -local bare_sessions = bare_sessions; +local full_sessions = prosody.full_sessions; +local bare_sessions = prosody.bare_sessions; local st = require "util.stanza"; local jid_bare = require "util.jid".bare; diff --git a/plugins/mod_motd.lua b/plugins/mod_motd.lua index fea2cb85..ed78294b 100644 --- a/plugins/mod_motd.lua +++ b/plugins/mod_motd.lua @@ -13,7 +13,6 @@ local motd_jid = module:get_option_string("motd_jid", host); if not motd_text then return; end -local jid_join = require "util.jid".join; local st = require "util.stanza"; motd_text = motd_text:gsub("^%s*(.-)%s*$", "%1"):gsub("\n%s+", "\n"); -- Strip indentation from the config diff --git a/plugins/mod_posix.lua b/plugins/mod_posix.lua index e871e5cf..28fd7f38 100644 --- a/plugins/mod_posix.lua +++ b/plugins/mod_posix.lua @@ -7,10 +7,12 @@ -- -local want_pposix_version = "0.3.5"; +local want_pposix_version = "0.3.6"; local pposix = assert(require "util.pposix"); -if pposix._VERSION ~= want_pposix_version then module:log("warn", "Unknown version (%s) of binary pposix module, expected %s", tostring(pposix._VERSION), want_pposix_version); end +if pposix._VERSION ~= want_pposix_version then + module:log("warn", "Unknown version (%s) of binary pposix module, expected %s. Perhaps you need to recompile?", tostring(pposix._VERSION), want_pposix_version); +end local signal = select(2, pcall(require, "util.signal")); if type(signal) == "string" then @@ -118,9 +120,9 @@ function syslog_sink_maker(config) local syslog, format = pposix.syslog_log, string.format; return function (name, level, message, ...) if ... then - syslog(level, format(message, ...)); + syslog(level, name, format(message, ...)); else - syslog(level, message); + syslog(level, name, message); end end; end diff --git a/plugins/mod_presence.lua b/plugins/mod_presence.lua index 23012750..8dac2d35 100644 --- a/plugins/mod_presence.lua +++ b/plugins/mod_presence.lua @@ -9,7 +9,7 @@ local log = module._log; local require = require; -local pairs, ipairs = pairs, ipairs; +local pairs = pairs; local t_concat, t_insert = table.concat, table.insert; local s_find = string.find; local tonumber = tonumber; @@ -19,7 +19,9 @@ local st = require "util.stanza"; local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; local datetime = require "util.datetime"; -local hosts = hosts; +local hosts = prosody.hosts; +local bare_sessions = prosody.bare_sessions; +local full_sessions = prosody.full_sessions; local NULL = {}; local rostermanager = require "core.rostermanager"; @@ -344,7 +346,7 @@ module:hook("presence/full", function(data) end); module:hook("presence/host", function(data) -- inbound presence to the host - local origin, stanza = data.origin, data.stanza; + local stanza = data.stanza; local from_bare = jid_bare(stanza.attr.from); local t = stanza.attr.type; diff --git a/plugins/mod_privacy.lua b/plugins/mod_privacy.lua index 2d696154..31ace9f9 100644 --- a/plugins/mod_privacy.lua +++ b/plugins/mod_privacy.lua @@ -9,16 +9,16 @@ module:add_feature("jabber:iq:privacy"); -local prosody = prosody; local st = require "util.stanza"; -local datamanager = require "util.datamanager"; -local bare_sessions, full_sessions = bare_sessions, full_sessions; +local bare_sessions, full_sessions = prosody.bare_sessions, prosody.full_sessions; local util_Jid = require "util.jid"; local jid_bare = util_Jid.bare; local jid_split, jid_join = util_Jid.split, util_Jid.join; local load_roster = require "core.rostermanager".load_roster; local to_number = tonumber; +local privacy_storage = module:open_store(); + function isListUsed(origin, name, privacy_lists) local user = bare_sessions[origin.username.."@"..origin.host]; if user then @@ -218,7 +218,7 @@ module:hook("iq/bare/jabber:iq:privacy:query", function(data) if stanza.attr.to == nil then -- only service requests to own bare JID local query = stanza.tags[1]; -- the query element local valid = false; - local privacy_lists = datamanager.load(origin.username, origin.host, "privacy") or { lists = {} }; + local privacy_lists = privacy_storage:get(origin.username) or { lists = {} }; if privacy_lists.lists[1] then -- Code to migrate from old privacy lists format, remove in 0.8 module:log("info", "Upgrading format of stored privacy lists for %s@%s", origin.username, origin.host); @@ -273,7 +273,7 @@ module:hook("iq/bare/jabber:iq:privacy:query", function(data) end origin.send(st.error_reply(stanza, valid[1], valid[2], valid[3])); else - datamanager.store(origin.username, origin.host, "privacy", privacy_lists); + privacy_storage:set(origin.username, privacy_lists); end return true; end @@ -281,7 +281,7 @@ end); function checkIfNeedToBeBlocked(e, session) local origin, stanza = e.origin, e.stanza; - local privacy_lists = datamanager.load(session.username, session.host, "privacy") or {}; + local privacy_lists = privacy_storage:get(session.username) or {}; local bare_jid = session.username.."@"..session.host; local to = stanza.attr.to or bare_jid; local from = stanza.attr.from; @@ -367,6 +367,10 @@ function checkIfNeedToBeBlocked(e, session) end if apply then if block then + -- drop and not bounce groupchat messages, otherwise users will get kicked + if stanza.attr.type == "groupchat" then + return true; + end module:log("debug", "stanza blocked: %s, to: %s, from: %s", tostring(stanza.name), tostring(to), tostring(from)); if stanza.name == "message" then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); diff --git a/plugins/mod_private.lua b/plugins/mod_private.lua index f1ebe786..365a997c 100644 --- a/plugins/mod_private.lua +++ b/plugins/mod_private.lua @@ -9,8 +9,7 @@ local st = require "util.stanza" -local jid_split = require "util.jid".split; -local datamanager = require "util.datamanager" +local private_storage = module:open_store(); module:add_feature("jabber:iq:private"); @@ -21,7 +20,7 @@ module:hook("iq/self/jabber:iq:private:query", function(event) if #query.tags == 1 then local tag = query.tags[1]; local key = tag.name..":"..tag.attr.xmlns; - local data, err = datamanager.load(origin.username, origin.host, "private"); + local data, err = private_storage:get(origin.username); if err then origin.send(st.error_reply(stanza, "wait", "internal-server-error")); return true; @@ -40,7 +39,7 @@ module:hook("iq/self/jabber:iq:private:query", function(event) data[key] = st.preserialize(tag); end -- TODO delete datastore if empty - if datamanager.store(origin.username, origin.host, "private", data) then + if private_storage:set(origin.username, data) then origin.send(st.reply(stanza)); else origin.send(st.error_reply(stanza, "wait", "internal-server-error")); diff --git a/plugins/mod_proxy65.lua b/plugins/mod_proxy65.lua index d6e41604..1fa42bd8 100644 --- a/plugins/mod_proxy65.lua +++ b/plugins/mod_proxy65.lua @@ -95,7 +95,7 @@ function module.add_host(module) local proxy_port = next(portmanager.get_active_services():search("proxy65", nil)[1] or {}); local proxy_acl = module:get_option("proxy65_acl"); - -- COMPAT w/pre-0.9 where proxy65_port was specified the components section of the config + -- COMPAT w/pre-0.9 where proxy65_port was specified in the components section of the config local legacy_config = module:get_option_number("proxy65_port"); if legacy_config then module:log("warn", "proxy65_port is deprecated, please put proxy65_ports = { %d } into the global section instead", legacy_config); @@ -106,16 +106,20 @@ function module.add_host(module) module:hook("iq-get/host/http://jabber.org/protocol/disco#info:query", function(event) local origin, stanza = event.origin, event.stanza; - origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#info") - :tag("identity", {category='proxy', type='bytestreams', name=name}):up() - :tag("feature", {var="http://jabber.org/protocol/bytestreams"}) ); - return true; + if not stanza.tags[1].attr.node then + origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#info") + :tag("identity", {category='proxy', type='bytestreams', name=name}):up() + :tag("feature", {var="http://jabber.org/protocol/bytestreams"}) ); + return true; + end end, -1); module:hook("iq-get/host/http://jabber.org/protocol/disco#items:query", function(event) local origin, stanza = event.origin, event.stanza; - origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#items")); - return true; + if not stanza.tags[1].attr.node then + origin.send(st.reply(stanza):query("http://jabber.org/protocol/disco#items")); + return true; + end end, -1); module:hook("iq-get/host/http://jabber.org/protocol/bytestreams:query", function(event) diff --git a/plugins/mod_pubsub.lua b/plugins/mod_pubsub.lua index fe6c0b0a..22969ab5 100644 --- a/plugins/mod_pubsub.lua +++ b/plugins/mod_pubsub.lua @@ -22,6 +22,9 @@ function handle_pubsub_iq(event) local origin, stanza = event.origin, event.stanza; local pubsub = stanza.tags[1]; local action = pubsub.tags[1]; + if not action then + return origin.send(st.error_reply(stanza, "cancel", "bad-request")); + end local handler = handlers[stanza.attr.type.."_"..action.name]; if handler then handler(origin, stanza, action); @@ -164,16 +167,6 @@ function handlers.set_subscribe(origin, stanza, subscribe) reply = pubsub_error_reply(stanza, ret); end origin.send(reply); - if ok then - -- Send all current items - local ok, items = service:get_items(node, stanza.attr.from); - if items then - local jids = { [jid] = options or true }; - for id, item in pairs(items) do - service.config.broadcaster("items", node, jids, item); - end - end - end end function handlers.set_unsubscribe(origin, stanza, unsubscribe) @@ -197,7 +190,13 @@ function handlers.set_publish(origin, stanza, publish) return origin.send(pubsub_error_reply(stanza, "nodeid-required")); end local item = publish:get_child("item"); - local id = (item and item.attr.id) or uuid_generate(); + local id = (item and item.attr.id); + if not id then + id = uuid_generate(); + if item then + item.attr.id = id; + end + end local ok, ret = service:publish(node, stanza.attr.from, id, item); local reply; if ok then diff --git a/plugins/mod_register.lua b/plugins/mod_register.lua index b3abd394..141a4997 100644 --- a/plugins/mod_register.lua +++ b/plugins/mod_register.lua @@ -7,9 +7,7 @@ -- -local hosts = _G.hosts; local st = require "util.stanza"; -local datamanager = require "util.datamanager"; local dataform_new = require "util.dataforms".new; local usermanager_user_exists = require "core.usermanager".user_exists; local usermanager_create_user = require "core.usermanager".create_user; @@ -23,6 +21,8 @@ local compat = module:get_option_boolean("registration_compat", true); local allow_registration = module:get_option_boolean("allow_registration", false); local additional_fields = module:get_option("additional_registration_fields", {}); +local account_details = module:open_store("account_details"); + local field_map = { username = { name = "username", type = "text-single", label = "Username", required = true }; password = { name = "password", type = "text-private", label = "Password", required = true }; @@ -235,7 +235,7 @@ module:hook("stanza/iq/jabber:iq:register:query", function(event) -- TODO unable to write file, file may be locked, etc, what's the correct error? local error_reply = st.error_reply(stanza, "wait", "internal-server-error", "Failed to write data to disk."); if usermanager_create_user(username, password, host) then - if next(data) and not datamanager.store(username, host, "account_details", data) then + if next(data) and not account_details:set(username, data) then usermanager_delete_user(username, host); session.send(error_reply); return true; diff --git a/plugins/mod_roster.lua b/plugins/mod_roster.lua index 40d95be7..d530bb45 100644 --- a/plugins/mod_roster.lua +++ b/plugins/mod_roster.lua @@ -69,7 +69,6 @@ module:hook("iq/self/jabber:iq:roster:query", function(event) and query.tags[1].attr.jid ~= "pending" then local item = query.tags[1]; local from_node, from_host = jid_split(stanza.attr.from); - local from_bare = from_node and (from_node.."@"..from_host) or from_host; -- bare JID local jid = jid_prep(item.attr.jid); local node, host, resource = jid_split(jid); if not resource and host then diff --git a/plugins/mod_s2s/mod_s2s.lua b/plugins/mod_s2s/mod_s2s.lua index 15c89ced..30ebb706 100644 --- a/plugins/mod_s2s/mod_s2s.lua +++ b/plugins/mod_s2s/mod_s2s.lua @@ -15,6 +15,7 @@ local core_process_stanza = prosody.core_process_stanza; local tostring, type = tostring, type; local t_insert = table.insert; local xpcall, traceback = xpcall, debug.traceback; +local NULL = {}; local add_task = require "util.timer".add_task; local st = require "util.stanza"; @@ -24,14 +25,19 @@ local new_xmpp_stream = require "util.xmppstream".new; local s2s_new_incoming = require "core.s2smanager".new_incoming; local s2s_new_outgoing = require "core.s2smanager".new_outgoing; local s2s_destroy_session = require "core.s2smanager".destroy_session; -local s2s_mark_connected = require "core.s2smanager".mark_connected; local uuid_gen = require "util.uuid".generate; local cert_verify_identity = require "util.x509".verify_identity; +local fire_global_event = prosody.events.fire_event; local s2sout = module:require("s2sout"); local connect_timeout = module:get_option_number("s2s_timeout", 90); local stream_close_timeout = module:get_option_number("s2s_close_timeout", 5); +local opt_keepalives = module:get_option_boolean("s2s_tcp_keepalives", module:get_option_boolean("tcp_keepalives", true)); +local secure_auth = module:get_option_boolean("s2s_secure_auth", false); -- One day... +local secure_domains, insecure_domains = + module:get_option_set("s2s_secure_domains", {})._items, module:get_option_set("s2s_insecure_domains", {})._items; +local require_encryption = module:get_option_boolean("s2s_require_encryption", secure_auth); local sessions = module:shared("sessions"); @@ -75,6 +81,10 @@ function route_to_existing_session(event) log("warn", "Attempt to send stanza from %s - a host we don't serve", from_host); return false; end + if hosts[to_host] then + log("warn", "Attempt to route stanza to a remote %s - a host we do serve?!", from_host); + return false; + end local host = hosts[from_host].s2sout[to_host]; if host then -- We have a connection to this host already @@ -130,12 +140,86 @@ function module.add_host(module) module:log("warn", "The 'disallow_s2s' config option is deprecated, please see http://prosody.im/doc/s2s#disabling"); return nil, "This host has disallow_s2s set"; end - module:hook("route/remote", route_to_existing_session, 200); - module:hook("route/remote", route_to_new_session, 100); + module:hook("route/remote", route_to_existing_session, -1); + module:hook("route/remote", route_to_new_session, -10); + module:hook("s2s-authenticated", make_authenticated, -1); +end + +-- Stream is authorised, and ready for normal stanzas +function mark_connected(session) + local sendq, send = session.sendq, session.sends2s; + + local from, to = session.from_host, session.to_host; + + session.log("info", "%s s2s connection %s->%s complete", session.direction, from, to); + + local event_data = { session = session }; + if session.type == "s2sout" then + fire_global_event("s2sout-established", event_data); + hosts[from].events.fire_event("s2sout-established", event_data); + else + local host_session = hosts[to]; + session.send = function(stanza) + return host_session.events.fire_event("route/remote", { from_host = to, to_host = from, stanza = stanza }); + end; + + fire_global_event("s2sin-established", event_data); + hosts[to].events.fire_event("s2sin-established", event_data); + end + + if session.direction == "outgoing" then + if sendq then + session.log("debug", "sending %d queued stanzas across new outgoing connection to %s", #sendq, session.to_host); + for i, data in ipairs(sendq) do + send(data[1]); + sendq[i] = nil; + end + session.sendq = nil; + end + + session.ip_hosts = nil; + session.srv_hosts = nil; + end +end + +function make_authenticated(event) + local session, host = event.session, event.host; + if not session.secure then + if require_encryption or secure_auth or secure_domains[host] then + session:close({ + condition = "policy-violation", + text = "Encrypted server-to-server communication is required but was not " + ..((session.direction == "outgoing" and "offered") or "used") + }); + end + end + if hosts[host] then + session:close({ condition = "undefined-condition", text = "Attempt to authenticate as a host we serve" }); + end + if session.type == "s2sout_unauthed" then + session.type = "s2sout"; + elseif session.type == "s2sin_unauthed" then + session.type = "s2sin"; + if host then + if not session.hosts[host] then session.hosts[host] = {}; end + session.hosts[host].authed = true; + end + elseif session.type == "s2sin" and host then + if not session.hosts[host] then session.hosts[host] = {}; end + session.hosts[host].authed = true; + else + return false; + end + session.log("debug", "connection %s->%s is now authenticated for %s", session.from_host, session.to_host, host); + + mark_connected(session); + + return true; end --- Helper to check that a session peer's certificate is valid local function check_cert_status(session) + local host = session.direction == "outgoing" and session.to_host or session.from_host local conn = session.conn:socket() local cert if conn.getpeercertificate then @@ -143,11 +227,19 @@ local function check_cert_status(session) end if cert then - local chain_valid, errors = conn:getpeerverification() + local chain_valid, errors; + if conn.getpeerverification then + chain_valid, errors = conn:getpeerverification(); + elseif conn.getpeerchainvalid then -- COMPAT mw/luasec-hg + chain_valid, errors = conn:getpeerchainvalid(); + errors = (not chain_valid) and { { errors } } or nil; + else + chain_valid, errors = false, { { "Chain verification not supported by this version of LuaSec" } }; + end -- Is there any interest in printing out all/the number of errors here? if not chain_valid then (session.log or log)("debug", "certificate chain validation result: invalid"); - for depth, t in ipairs(errors) do + for depth, t in ipairs(errors or NULL) do (session.log or log)("debug", "certificate error(s) at depth %d: %s", depth-1, table.concat(t, ", ")) end session.cert_chain_status = "invalid"; @@ -155,8 +247,6 @@ local function check_cert_status(session) (session.log or log)("debug", "certificate chain validation result: valid"); session.cert_chain_status = "valid"; - local host = session.direction == "incoming" and session.from_host or session.to_host - -- We'll go ahead and verify the asserted identity if the -- connecting server specified one. if host then @@ -168,6 +258,7 @@ local function check_cert_status(session) end end end + return module:fire_event("s2s-check-certificate", { host = host, session = session, cert = cert }); end --- XMPP stream event handlers @@ -246,11 +337,18 @@ function stream_callbacks.streamopened(session, attr) end end - if session.secure and not session.cert_chain_status then check_cert_status(session); end + if hosts[from] then + session:close({ condition = "undefined-condition", text = "Attempt to connect from a host we serve" }); + return; + end + + if session.secure and not session.cert_chain_status then + if check_cert_status(session) == false then + return; + end + end - send("<?xml version='1.0'?>"); - send(st.stanza("stream:stream", { xmlns='jabber:server', ["xmlns:db"]='jabber:server:dialback', - ["xmlns:stream"]='http://etherx.jabber.org/streams', id=session.streamid, from=to, to=from, version=(session.version > 0 and "1.0" or nil) }):top_tag()); + session:open_stream(session.to_host, session.from_host) if session.version >= 1.0 then local features = st.stanza("stream:features"); @@ -268,7 +366,11 @@ function stream_callbacks.streamopened(session, attr) if not attr.id then error("stream response did not give us a streamid!!!"); end session.streamid = attr.id; - if session.secure and not session.cert_chain_status then check_cert_status(session); end + if session.secure and not session.cert_chain_status then + if check_cert_status(session) == false then + return; + end + end -- Send unauthed buffer -- (stanzas which are fine to send before dialback) @@ -287,9 +389,9 @@ function stream_callbacks.streamopened(session, attr) -- If server is pre-1.0, don't wait for features, just do dialback if session.version < 1.0 then if not session.dialback_verifying then - hosts[session.from_host].events.fire_event("s2s-authenticate-legacy", { origin = session }); + hosts[session.from_host].events.fire_event("s2sout-authenticate-legacy", { origin = session }); else - s2s_mark_connected(session); + mark_connected(session); end end end @@ -327,7 +429,7 @@ function stream_callbacks.error(session, error, data) end end -local function handleerr(err) log("error", "Traceback[s2s]: %s: %s", tostring(err), traceback()); end +local function handleerr(err) log("error", "Traceback[s2s]: %s", traceback(tostring(err), 2)); end function stream_callbacks.handlestanza(session, stanza) if stanza.attr.xmlns == "jabber:client" then --COMPAT: Prosody pre-0.6.2 may send jabber:client stanza.attr.xmlns = nil; @@ -342,13 +444,15 @@ local listener = {}; --- Session methods local stream_xmlns_attr = {xmlns='urn:ietf:params:xml:ns:xmpp-streams'}; -local default_stream_attr = { ["xmlns:stream"] = "http://etherx.jabber.org/streams", xmlns = stream_callbacks.default_ns, version = "1.0", id = "" }; local function session_close(session, reason, remote_reason) local log = session.log or log; if session.conn then if session.notopen then - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); + if session.direction == "incoming" then + session:open_stream(session.to_host, session.from_host); + else + session:open_stream(session.from_host, session.to_host); + end end if reason then -- nil == no err, initiated by us, false == initiated by remote if type(reason) == "string" then -- assume stream error @@ -395,6 +499,24 @@ local function session_close(session, reason, remote_reason) end end +function session_open_stream(session, from, to) + local attr = { + ["xmlns:stream"] = 'http://etherx.jabber.org/streams', + xmlns = 'jabber:server', + version = session.version and (session.version > 0 and "1.0" or nil), + ["xml:lang"] = 'en', + id = session.streamid, + from = from, to = to, + } + if not from or (hosts[from] and hosts[from].modules.dialback) then + attr["xmlns:db"] = 'jabber:server:dialback'; + end + + session.sends2s("<?xml version='1.0'?>"); + session.sends2s(st.stanza("stream:stream", attr):top_tag()); + return true; +end + -- Session initialization logic shared by incoming and outgoing local function initialize_session(session) local stream = new_xmpp_stream(session, stream_callbacks); @@ -406,6 +528,8 @@ local function initialize_session(session) session.notopen = true; session.stream:reset(); end + + session.open_stream = session_open_stream; local filter = session.filter; function session.data(data) @@ -440,6 +564,7 @@ local function initialize_session(session) end function listener.onconnect(conn) + conn:setoption("keepalive", opt_keepalives); local session = sessions[conn]; if not session then -- New incoming connection session = s2s_new_incoming(conn); @@ -506,6 +631,29 @@ function listener.register_outgoing(conn, session) initialize_session(session); end +function check_auth_policy(event) + local host, session = event.host, event.session; + local must_secure = secure_auth; + + if not must_secure and secure_domains[host] then + must_secure = true; + elseif must_secure and insecure_domains[host] then + must_secure = false; + end + + if must_secure and not session.cert_identity_status then + module:log("warn", "Forbidding insecure connection to/from %s", host); + if session.direction == "incoming" then + session:close({ condition = "not-authorized", text = "Your server's certificate is invalid, expired, or not trusted by "..session.to_host }); + else -- Close outgoing connections without warning + session:close(false); + end + return false; + end +end + +module:hook("s2s-check-certificate", check_auth_policy, -1); + s2sout.set_listener(listener); module:hook("server-stopping", function(event) diff --git a/plugins/mod_s2s/s2sout.lib.lua b/plugins/mod_s2s/s2sout.lib.lua index 07623968..cb2f8be4 100644 --- a/plugins/mod_s2s/s2sout.lib.lua +++ b/plugins/mod_s2s/s2sout.lib.lua @@ -13,7 +13,7 @@ local wrapclient = require "net.server".wrapclient; local initialize_filters = require "util.filters".initialize; local idna_to_ascii = require "util.encodings".idna.to_ascii; local new_ip = require "util.ip".new_ip; -local rfc3484_dest = require "util.rfc3484".destination; +local rfc6724_dest = require "util.rfc6724".destination; local socket = require "socket"; local adns = require "net.adns"; local dns = require "net.dns"; @@ -44,15 +44,9 @@ local function compare_srv_priorities(a,b) return a.priority < b.priority or (a.priority == b.priority and a.weight > b.weight); end -local function session_open_stream(session, from, to) - session.sends2s(st.stanza("stream:stream", { - xmlns='jabber:server', ["xmlns:db"]='jabber:server:dialback', - ["xmlns:stream"]='http://etherx.jabber.org/streams', - from=from, to=to, version='1.0', ["xml:lang"]='en'}):top_tag()); -end - function s2sout.initiate_connection(host_session) initialize_filters(host_session); + host_session.version = 1; host_session.open_stream = session_open_stream; -- Kick the connection attempting machine into life @@ -96,7 +90,7 @@ function s2sout.attempt_connection(host_session, err) host_session.connecting = nil; if answer and #answer > 0 then log("debug", "%s has SRV records, handling...", to_host); - local srv_hosts = {}; + local srv_hosts = { answer = answer }; host_session.srv_hosts = srv_hosts; for _, record in ipairs(answer) do t_insert(srv_hosts, record.srv); @@ -197,7 +191,7 @@ function s2sout.try_connect(host_session, connect_host, connect_port, err) if have_other_result then if #IPs > 0 then - rfc3484_dest(host_session.ip_hosts, sources); + rfc6724_dest(host_session.ip_hosts, sources); for i = 1, #IPs do IPs[i] = {ip = IPs[i], port = connect_port}; end @@ -233,7 +227,7 @@ function s2sout.try_connect(host_session, connect_host, connect_port, err) if have_other_result then if #IPs > 0 then - rfc3484_dest(host_session.ip_hosts, sources); + rfc6724_dest(host_session.ip_hosts, sources); for i = 1, #IPs do IPs[i] = {ip = IPs[i], port = connect_port}; end @@ -277,6 +271,10 @@ function s2sout.make_connect(host_session, connect_host, connect_port) local from_host, to_host = host_session.from_host, host_session.to_host; + -- Reset secure flag in case this is another + -- connection attempt after a failed STARTTLS + host_session.secure = nil; + local conn, handler; if connect_host.proto == "IPv4" then conn, handler = socket.tcp(); diff --git a/plugins/mod_saslauth.lua b/plugins/mod_saslauth.lua index f6abd3b8..201cc477 100644 --- a/plugins/mod_saslauth.lua +++ b/plugins/mod_saslauth.lua @@ -11,7 +11,6 @@ local st = require "util.stanza"; local sm_bind_resource = require "core.sessionmanager".bind_resource; local sm_make_authenticated = require "core.sessionmanager".make_authenticated; -local s2s_make_authenticated = require "core.s2smanager".make_authenticated; local base64 = require "util.encodings".base64; local cert_verify_identity = require "util.x509".verify_identity; @@ -88,13 +87,9 @@ module:hook_stanza(xmlns_sasl, "success", function (session, stanza) module:log("debug", "SASL EXTERNAL with %s succeeded", session.to_host); session.external_auth = "succeeded" session:reset_stream(); + session:open_stream(session.from_host, session.to_host); - local default_stream_attr = {xmlns = "jabber:server", ["xmlns:stream"] = "http://etherx.jabber.org/streams", - ["xmlns:db"] = 'jabber:server:dialback', version = "1.0", to = session.to_host, from = session.from_host}; - session.sends2s("<?xml version='1.0'?>"); - session.sends2s(st.stanza("stream:stream", default_stream_attr):top_tag()); - - s2s_make_authenticated(session, session.to_host); + module:fire_event("s2s-authenticated", { session = session, host = session.to_host }); return true; end) @@ -191,7 +186,7 @@ local function s2s_external_auth(session, stanza) local domain = text ~= "" and text or session.from_host; module:log("info", "Accepting SASL EXTERNAL identity from %s", domain); - s2s_make_authenticated(session, domain); + module:fire_event("s2s-authenticated", { session = session, host = domain }); session:reset_stream(); return true end diff --git a/plugins/mod_storage_none.lua b/plugins/mod_storage_none.lua new file mode 100644 index 00000000..8f2d2f56 --- /dev/null +++ b/plugins/mod_storage_none.lua @@ -0,0 +1,23 @@ +local driver = {}; +local driver_mt = { __index = driver }; + +function driver:open(store) + return setmetatable({ store = store }, driver_mt); +end +function driver:get(user) + return {}; +end + +function driver:set(user, data) + return nil, "Storage disabled"; +end + +function driver:stores(username) + return { "roster" }; +end + +function driver:purge(user) + return true; +end + +module:provides("storage", driver); diff --git a/plugins/mod_storage_sql2.lua b/plugins/mod_storage_sql2.lua new file mode 100644 index 00000000..7d705b0b --- /dev/null +++ b/plugins/mod_storage_sql2.lua @@ -0,0 +1,237 @@ + +local json = require "util.json"; +local resolve_relative_path = require "core.configmanager".resolve_relative_path; + +local mod_sql = module:require("sql"); +local params = module:get_option("sql"); + +local engine; -- TODO create engine + +local function create_table() + --[[local Table,Column,Index = mod_sql.Table,mod_sql.Column,mod_sql.Index; + local ProsodyTable = Table { + name="prosody"; + Column { name="host", type="TEXT", nullable=false }; + Column { name="user", type="TEXT", nullable=false }; + Column { name="store", type="TEXT", nullable=false }; + Column { name="key", type="TEXT", nullable=false }; + Column { name="type", type="TEXT", nullable=false }; + Column { name="value", type="TEXT", nullable=false }; + Index { name="prosody_index", "host", "user", "store", "key" }; + }; + engine:transaction(function() + ProsodyTable:create(engine); + end);]] + if not module:get_option("sql_manage_tables", true) then + return; + end + + local create_sql = "CREATE TABLE `prosody` (`host` TEXT, `user` TEXT, `store` TEXT, `key` TEXT, `type` TEXT, `value` TEXT);"; + if params.driver == "PostgreSQL" then + create_sql = create_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + create_sql = create_sql:gsub("`value` TEXT", "`value` MEDIUMTEXT") + :gsub(";$", " CHARACTER SET 'utf8' COLLATE 'utf8_bin';"); + end + + local index_sql = "CREATE INDEX `prosody_index` ON `prosody` (`host`, `user`, `store`, `key`)"; + if params.driver == "PostgreSQL" then + index_sql = index_sql:gsub("`", "\""); + elseif params.driver == "MySQL" then + index_sql = index_sql:gsub("`([,)])", "`(20)%1"); + end + + local success,err = engine:transaction(function() + engine:execute(create_sql); + engine:execute(index_sql); + end); + if not success then -- so we failed to create + if params.driver == "MySQL" then + success,err = engine:transaction(function() + local result = engine:execute("SHOW COLUMNS FROM prosody WHERE Field='value' and Type='text'"); + if result:rowcount() > 0 then + module:log("info", "Upgrading database schema..."); + engine:execute("ALTER TABLE prosody MODIFY COLUMN `value` MEDIUMTEXT"); + module:log("info", "Database table automatically upgraded"); + end + return true; + end); + if not success then + module:log("error", "Failed to check/upgrade database schema (%s), please see " + .."http://prosody.im/doc/mysql for help", + err or "unknown error"); + end + end + end +end +local function set_encoding() + if params.driver ~= "SQLite3" then + local set_names_query = "SET NAMES 'utf8';"; + if params.driver == "MySQL" then + set_names_query = set_names_query:gsub(";$", " COLLATE 'utf8_bin';"); + end + local success,err = engine:transaction(function() return engine:execute(set_names_query); end); + if not success then + module:log("error", "Failed to set database connection encoding to UTF8: %s", err); + return; + end + if params.driver == "MySQL" then + -- COMPAT w/pre-0.9: Upgrade tables to UTF-8 if not already + local check_encoding_query = "SELECT `COLUMN_NAME`,`COLUMN_TYPE` FROM `information_schema`.`columns` WHERE `TABLE_NAME`='prosody' AND ( `CHARACTER_SET_NAME`!='utf8' OR `COLLATION_NAME`!='utf8_bin' );"; + local success,err = engine:transaction(function() + local result = engine:execute(check_encoding_query); + local n_bad_columns = result:rowcount(); + if n_bad_columns > 0 then + module:log("warn", "Found %d columns in prosody table requiring encoding change, updating now...", n_bad_columns); + local fix_column_query1 = "ALTER TABLE `prosody` CHANGE `%s` `%s` BLOB;"; + local fix_column_query2 = "ALTER TABLE `prosody` CHANGE `%s` `%s` %s CHARACTER SET 'utf8' COLLATE 'utf8_bin';"; + for row in success:rows() do + local column_name, column_type = unpack(row); + engine:execute(fix_column_query1:format(column_name, column_name)); + engine:execute(fix_column_query2:format(column_name, column_name, column_type)); + end + module:log("info", "Database encoding upgrade complete!"); + end + end); + local success,err = engine:transaction(function() return engine:execute(check_encoding_query); end); + if not success then + module:log("error", "Failed to check/upgrade database encoding: %s", err or "unknown error"); + end + end + end +end + +do -- process options to get a db connection + params = params or { driver = "SQLite3" }; + + if params.driver == "SQLite3" then + params.database = resolve_relative_path(prosody.paths.data or ".", params.database or "prosody.sqlite"); + end + + assert(params.driver and params.database, "Both the SQL driver and the database need to be specified"); + + --local dburi = db2uri(params); + engine = mod_sql:create_engine(params); + + -- Encoding mess + set_encoding(); + + -- Automatically create table, ignore failure (table probably already exists) + create_table(); +end + +local function serialize(value) + local t = type(value); + if t == "string" or t == "boolean" or t == "number" then + return t, tostring(value); + elseif t == "table" then + local value,err = json.encode(value); + if value then return "json", value; end + return nil, err; + end + return nil, "Unhandled value type: "..t; +end +local function deserialize(t, value) + if t == "string" then return value; + elseif t == "boolean" then + if value == "true" then return true; + elseif value == "false" then return false; end + elseif t == "number" then return tonumber(value); + elseif t == "json" then + return json.decode(value); + end +end + +local host = module.host; +local user, store; + +local function keyval_store_get() + local haveany; + local result = {}; + for row in engine:select("SELECT `key`,`type`,`value` FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=?", host, user, store) do + haveany = true; + local k = row[1]; + local v = deserialize(row[2], row[3]); + if k and v then + if k ~= "" then result[k] = v; elseif type(v) == "table" then + for a,b in pairs(v) do + result[a] = b; + end + end + end + end + if haveany then + return result; + end +end +local function keyval_store_set(data) + engine:delete("DELETE FROM `prosody` WHERE `host`=? AND `user`=? AND `store`=?", host, user, store); + + if data and next(data) ~= nil then + local extradata = {}; + for key, value in pairs(data) do + if type(key) == "string" and key ~= "" then + local t, value = serialize(value); + assert(t, value); + engine:insert("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", host, user, store, key, t, value); + else + extradata[key] = value; + end + end + if next(extradata) ~= nil then + local t, extradata = serialize(extradata); + assert(t, extradata); + engine:insert("INSERT INTO `prosody` (`host`,`user`,`store`,`key`,`type`,`value`) VALUES (?,?,?,?,?,?)", host, user, store, "", t, extradata); + end + end + return true; +end + +local keyval_store = {}; +keyval_store.__index = keyval_store; +function keyval_store:get(username) + user,store = username,self.store; + return select(2, engine:transaction(keyval_store_get)); +end +function keyval_store:set(username, data) + user,store = username,self.store; + return engine:transaction(function() + return keyval_store_set(data); + end); +end +function keyval_store:users() + return engine:transaction(function() + return engine:select("SELECT DISTINCT `user` FROM `prosody` WHERE `host`=? AND `store`=?", host, self.store); + end); +end + +local driver = {}; + +function driver:open(store, typ) + if not typ then -- default key-value store + return setmetatable({ store = store }, keyval_store); + end + return nil, "unsupported-store"; +end + +function driver:stores(username) + local sql = "SELECT DISTINCT `store` FROM `prosody` WHERE `host`=? AND `user`" .. + (username == true and "!=?" or "=?"); + if username == true or not username then + username = ""; + end + return engine:transaction(function() + return engine:select(sql, host, username); + end); +end + +function driver:purge(username) + return engine:transaction(function() + local stmt,err = engine:delete("DELETE FROM `prosody` WHERE `host`=? AND `user`=?", host, username); + return true,err; + end); +end + +module:provides("storage", driver); + + diff --git a/plugins/mod_tls.lua b/plugins/mod_tls.lua index 707ae8f5..80b56abb 100644 --- a/plugins/mod_tls.lua +++ b/plugins/mod_tls.lua @@ -25,6 +25,7 @@ if secure_s2s_only then s2s_feature:tag("required"):up(); end local global_ssl_ctx = prosody.global_ssl_ctx; +local hosts = prosody.hosts; local host = hosts[module.host]; local function can_do_tls(session) @@ -91,10 +92,10 @@ module:hook_stanza(xmlns_starttls, "proceed", function (session, stanza) end); function module.load() - local ssl_config = config.rawget(module.host, "core", "ssl"); + local ssl_config = config.rawget(module.host, "ssl"); if not ssl_config then local base_host = module.host:match("%.(.*)"); - ssl_config = config.get(base_host, "core", "ssl"); + ssl_config = config.get(base_host, "ssl"); end host.ssl_ctx = create_context(host.host, "client", ssl_config); -- for outgoing connections host.ssl_ctx_in = create_context(host.host, "server", ssl_config); -- for incoming connections diff --git a/plugins/mod_vcard.lua b/plugins/mod_vcard.lua index d3c27cc0..26b30e3a 100644 --- a/plugins/mod_vcard.lua +++ b/plugins/mod_vcard.lua @@ -8,7 +8,8 @@ local st = require "util.stanza" local jid_split = require "util.jid".split; -local datamanager = require "util.datamanager" + +local vcards = module:open_store(); module:add_feature("vcard-temp"); @@ -19,9 +20,9 @@ local function handle_vcard(event) local vCard; if to then local node, host = jid_split(to); - vCard = st.deserialize(datamanager.load(node, host, "vcard")); -- load vCard for user or server + vCard = st.deserialize(vcards:get(node)); -- load vCard for user or server else - vCard = st.deserialize(datamanager.load(session.username, session.host, "vcard"));-- load user's own vCard + vCard = st.deserialize(vcards:get(session.username));-- load user's own vCard end if vCard then session.send(st.reply(stanza):add_child(vCard)); -- send vCard! @@ -30,7 +31,7 @@ local function handle_vcard(event) end else if not to then - if datamanager.store(session.username, session.host, "vcard", st.preserialize(stanza.tags[1])) then + if vcards:set(session.username, st.preserialize(stanza.tags[1])) then session.send(st.reply(stanza)); else -- TODO unable to write file, file may be locked, etc, what's the correct error? diff --git a/plugins/muc/mod_muc.lua b/plugins/muc/mod_muc.lua index 0df8b790..7861092c 100644 --- a/plugins/muc/mod_muc.lua +++ b/plugins/muc/mod_muc.lua @@ -28,13 +28,14 @@ local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; local st = require "util.stanza"; local uuid_gen = require "util.uuid".generate; -local datamanager = require "util.datamanager"; local um_is_admin = require "core.usermanager".is_admin; -local hosts = hosts; +local hosts = prosody.hosts; rooms = {}; local rooms = rooms; -local persistent_rooms = datamanager.load(nil, muc_host, "persistent") or {}; +local persistent_rooms_storage = module:open_store("persistent"); +local persistent_rooms = persistent_rooms_storage:get() or {}; +local room_configs = module:open_store("config"); -- Configurable options muclib.set_max_history_length(module:get_option_number("max_history_messages")); @@ -66,15 +67,15 @@ local function room_save(room, forced) _data = room._data; _affiliations = room._affiliations; }; - datamanager.store(node, muc_host, "config", data); + room_configs:set(node, data); room._data.history = history; elseif forced then - datamanager.store(node, muc_host, "config", nil); + room_configs:set(node, nil); if not next(room._occupants) then -- Room empty rooms[room.jid] = nil; end end - if forced then datamanager.store(nil, muc_host, "persistent", persistent_rooms); end + if forced then persistent_rooms_storage:set(nil, persistent_rooms); end end function create_room(jid) @@ -88,7 +89,7 @@ end local persistent_errors = false; for jid in pairs(persistent_rooms) do local node = jid_split(jid); - local data = datamanager.load(node, muc_host, "config"); + local data = room_configs:get(node); if data then local room = create_room(jid); room._data = data._data; @@ -99,7 +100,7 @@ for jid in pairs(persistent_rooms) do persistent_errors = true; end end -if persistent_errors then datamanager.store(nil, muc_host, "persistent", persistent_rooms); end +if persistent_errors then persistent_rooms_storage:set(nil, persistent_rooms); end local host_room = muc_new_room(muc_host); host_room.route_stanza = room_route_stanza; @@ -126,9 +127,10 @@ local function handle_to_domain(event) if type == "error" or type == "result" then return; end if stanza.name == "iq" and type == "get" then local xmlns = stanza.tags[1].attr.xmlns; - if xmlns == "http://jabber.org/protocol/disco#info" then + local node = stanza.tags[1].attr.node; + if xmlns == "http://jabber.org/protocol/disco#info" and not node then origin.send(get_disco_info(stanza)); - elseif xmlns == "http://jabber.org/protocol/disco#items" then + elseif xmlns == "http://jabber.org/protocol/disco#items" and not node then origin.send(get_disco_items(stanza)); elseif xmlns == "http://jabber.org/protocol/muc#unique" then origin.send(st.reply(stanza):tag("unique", {xmlns = xmlns}):text(uuid_gen())); -- FIXME Random UUIDs can theoretically have collisions diff --git a/plugins/muc/muc.lib.lua b/plugins/muc/muc.lib.lua index 16a0238d..a5aba3c8 100644 --- a/plugins/muc/muc.lib.lua +++ b/plugins/muc/muc.lib.lua @@ -88,6 +88,10 @@ local function getText(stanza, path) return getUsingPath(stanza, path, true); en local room_mt = {}; room_mt.__index = room_mt; +function room_mt:__tostring() + return "MUC room ("..self.jid..")"; +end + function room_mt:get_default_role(affiliation) if affiliation == "owner" or affiliation == "admin" then return "moderator"; @@ -576,10 +580,9 @@ function room_mt:send_form(origin, stanza) end function room_mt:get_form_layout() - local title = "Configuration for "..self.jid; - return dataform.new({ - title = title, - instructions = title, + local form = dataform.new({ + title = "Configuration for "..self.jid, + instructions = "Complete and submit this form to configure the room.", { name = 'FORM_TYPE', type = 'hidden', @@ -649,6 +652,7 @@ function room_mt:get_form_layout() value = tostring(self:get_historylength()) } }); + return module:fire_event("muc-config-form", { room = self, form = form }) or form; end local valid_whois = { @@ -669,6 +673,10 @@ function room_mt:process_form(origin, stanza) local dirty = false + local event = { room = self, fields = fields, changed = dirty }; + module:fire_event("muc-config-submitted", event); + dirty = event.changed or dirty; + local name = fields['muc#roomconfig_roomname']; if name ~= self:get_name() then self:set_name(name); @@ -765,13 +773,9 @@ function room_mt:handle_to_room(origin, stanza) -- presence changes and groupcha local type = stanza.attr.type; local xmlns = stanza.tags[1] and stanza.tags[1].attr.xmlns; if stanza.name == "iq" then - if xmlns == "http://jabber.org/protocol/disco#info" and type == "get" then - if stanza.tags[1].attr.node then - origin.send(st.error_reply(stanza, "cancel", "feature-not-implemented")); - else - origin.send(self:get_disco_info(stanza)); - end - elseif xmlns == "http://jabber.org/protocol/disco#items" and type == "get" then + if xmlns == "http://jabber.org/protocol/disco#info" and type == "get" and not stanza.tags[1].attr.node then + origin.send(self:get_disco_info(stanza)); + elseif xmlns == "http://jabber.org/protocol/disco#items" and type == "get" and not stanza.tags[1].attr.node then origin.send(self:get_disco_items(stanza)); elseif xmlns == "http://jabber.org/protocol/muc#admin" then local actor = stanza.attr.from; @@ -896,7 +900,7 @@ function room_mt:handle_to_room(origin, stanza) -- presence changes and groupcha origin.send(st.error_reply(stanza, "auth", "forbidden")); end else - self:broadcast_message(stanza, self:get_historylength() > 0); + self:broadcast_message(stanza, self:get_historylength() > 0 and stanza:get_child("body")); end stanza.attr.from = from; end @@ -987,7 +991,7 @@ function room_mt:set_affiliation(actor, jid, affiliation, callback, reason) return true; end if actor_affiliation ~= "owner" then - if actor_affiliation ~= "admin" or target_affiliation == "owner" or target_affiliation == "admin" then + if affiliation == "owner" or affiliation == "admin" or actor_affiliation ~= "admin" or target_affiliation == "owner" or target_affiliation == "admin" then return nil, "cancel", "not-allowed"; end elseif target_affiliation == "owner" and jid_bare(actor) == jid then -- self change @@ -1049,11 +1053,12 @@ function room_mt:get_role(nick) return session and session.role or nil; end function room_mt:can_set_role(actor_jid, occupant_jid, role) - local actor = self._occupants[self._jid_nick[actor_jid]]; local occupant = self._occupants[occupant_jid]; - if not occupant or not actor then return nil, "modify", "not-acceptable"; end + if actor_jid == true then return true; end + + local actor = self._occupants[self._jid_nick[actor_jid]]; if actor.role == "moderator" then if occupant.affiliation ~= "owner" and occupant.affiliation ~= "admin" then if actor.affiliation == "owner" or actor.affiliation == "admin" then diff --git a/plugins/sql.lib.lua b/plugins/sql.lib.lua new file mode 100644 index 00000000..005ee45d --- /dev/null +++ b/plugins/sql.lib.lua @@ -0,0 +1,9 @@ +local cache = module:shared("/*/sql.lib/util.sql"); + +if not cache._M then + prosody.unlock_globals(); + cache._M = require "util.sql"; + prosody.lock_globals(); +end + +return cache._M; @@ -132,8 +132,8 @@ end function sanity_check() for host, host_config in pairs(config.getconfig()) do if host ~= "*" - and host_config.core.enabled ~= false - and not host_config.core.component_module then + and host_config.enabled ~= false + and not host_config.component_module then return; end end @@ -198,6 +198,7 @@ function set_function_metatable() end function init_global_state() + -- COMPAT: These globals are deprecated bare_sessions = {}; full_sessions = {}; hosts = {}; @@ -206,8 +207,8 @@ function init_global_state() prosody.full_sessions = full_sessions; prosody.hosts = hosts; - local data_path = config.get("*", "core", "data_path") or CFG_DATADIR or "data"; - local custom_plugin_paths = config.get("*", "core", "plugin_paths"); + local data_path = config.get("*", "data_path") or CFG_DATADIR or "data"; + local custom_plugin_paths = config.get("*", "plugin_paths"); if custom_plugin_paths then local path_sep = package.config:sub(3,3); -- path1;path2;path3;defaultpath... @@ -289,12 +290,12 @@ function load_secondary_libraries() --- Load and initialise core modules require "util.import" require "util.xmppstream" - require "core.rostermanager" require "core.stanza_router" require "core.hostmanager" require "core.portmanager" require "core.modulemanager" require "core.usermanager" + require "core.rostermanager" require "core.sessionmanager" package.loaded['core.componentmanager'] = setmetatable({},{__index=function() log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); diff --git a/prosody.cfg.lua.dist b/prosody.cfg.lua.dist index 9ca9608a..3c199f3e 100644 --- a/prosody.cfg.lua.dist +++ b/prosody.cfg.lua.dist @@ -41,6 +41,8 @@ modules_enabled = { -- Not essential, but recommended "private"; -- Private XML storage (for room bookmarks, etc.) "vcard"; -- Allow users to set vCards + + -- These are commented by default as they have a performance impact --"privacy"; -- Support privacy lists --"compression"; -- Stream compression @@ -51,7 +53,6 @@ modules_enabled = { "ping"; -- Replies to XMPP pings with pongs "pep"; -- Enables users to publish their mood, activity, playing music and more "register"; -- Allow users to register on this server using a client and change passwords - "adhoc"; -- Support for "ad-hoc commands" that can be executed with an XMPP client -- Admin interfaces "admin_adhoc"; -- Allows administration via an XMPP client that supports ad-hoc commands @@ -71,14 +72,12 @@ modules_enabled = { --"legacyauth"; -- Legacy authentication. Only used by some old clients and bots. }; --- These modules are auto-loaded, should you --- (for some mad reason) want to disable --- them then uncomment them below +-- These modules are auto-loaded, but should you want +-- to disable them then uncomment them here: modules_disabled = { - -- "presence"; -- Route user/contact status information - -- "message"; -- Route messages - -- "iq"; -- Route info queries -- "offline"; -- Store offline messages + -- "c2s"; -- Handle client connections + -- "s2s"; -- Handle server-to-server connections }; -- Disable account creation by default, for security @@ -92,14 +91,28 @@ ssl = { certificate = "certs/localhost.crt"; } --- Only allow encrypted streams? Encryption is already used when --- available. These options will cause Prosody to deny connections that --- are not encrypted. Note that some servers do not support s2s --- encryption or have it disabled, including gmail.com and Google Apps --- domains. +-- Force clients to use encrypted connections? This option will +-- prevent clients from authenticating unless they are using encryption. + +c2s_require_encryption = false + +-- Force certificate authentication for server-to-server connections? +-- This provides ideal security, but requires servers you communicate +-- with to support encryption AND present valid, trusted certificates. +-- For more information see http://prosody.im/doc/s2s#security + +s2s_secure = true + +-- Many servers don't support encryption or have invalid or self-signed +-- certificates. You can list domains here that will not be required to +-- authenticate using certificates. They will be authenticated using DNS. + +-- s2s_insecure_domains = { "gmail.com" } + +-- Even if you leave s2s_secure disabled, you can still require it for +-- some domains by specifying a list here. ---c2s_require_encryption = false ---s2s_require_encryption = false +-- s2s_secure_domains = { "jabber.org" } -- Select the authentication backend to use. The 'internal' providers -- use Prosody's configured data storage to store the authentication data. @@ -51,6 +51,7 @@ local prosody = { lock_globals = function () end; unlock_globals = function () end; installed = CFG_SOURCEDIR ~= nil; + core_post_stanza = function () end; -- TODO: mod_router! }; _G.prosody = prosody; @@ -109,11 +110,11 @@ do os.exit(1); end end -local original_logging_config = config.get("*", "core", "log"); -config.set("*", "core", "log", { { levels = { min="info" }, to = "console" } }); +local original_logging_config = config.get("*", "log"); +config.set("*", "log", { { levels = { min="info" }, to = "console" } }); -local data_path = config.get("*", "core", "data_path") or CFG_DATADIR or "data"; -local custom_plugin_paths = config.get("*", "core", "plugin_paths"); +local data_path = config.get("*", "data_path") or CFG_DATADIR or "data"; +local custom_plugin_paths = config.get("*", "plugin_paths"); if custom_plugin_paths then local path_sep = package.config:sub(3,3); -- path1;path2;path3;defaultpath... @@ -134,7 +135,7 @@ dependencies.log_warnings(); -- Switch away from root and into the prosody user -- local switched_user, current_uid; -local want_pposix_version = "0.3.5"; +local want_pposix_version = "0.3.6"; local ok, pposix = pcall(require, "util.pposix"); if ok and pposix then @@ -142,8 +143,8 @@ if ok and pposix then current_uid = pposix.getuid(); if current_uid == 0 then -- We haz root! - local desired_user = config.get("*", "core", "prosody_user") or "prosody"; - local desired_group = config.get("*", "core", "prosody_group") or desired_user; + local desired_user = config.get("*", "prosody_user") or "prosody"; + local desired_group = config.get("*", "prosody_group") or desired_user; local ok, err = pposix.setgid(desired_group); if ok then ok, err = pposix.initgroups(desired_user); @@ -162,7 +163,7 @@ if ok and pposix then end -- Set our umask to protect data files - pposix.umask(config.get("*", "core", "umask") or "027"); + pposix.umask(config.get("*", "umask") or "027"); pposix.setenv("HOME", data_path); pposix.setenv("PROSODY_CONFIG", ENV_CONFIG); else @@ -267,7 +268,7 @@ local show_yesno = prosodyctl.show_yesno; local show_prompt = prosodyctl.show_prompt; local read_password = prosodyctl.read_password; -local prosodyctl_timeout = (config.get("*", "core", "prosodyctl_timeout") or 5) * 2; +local prosodyctl_timeout = (config.get("*", "prosodyctl_timeout") or 5) * 2; ----------------------- local commands = {}; local command = arg[1]; @@ -410,7 +411,7 @@ function commands.start(arg) local ok, ret = prosodyctl.start(); if ok then - if config.get("*", "core", "daemonize") ~= false then + if config.get("*", "daemonize") ~= false then local i=1; while true do local ok, running = prosodyctl.isrunning(); @@ -653,25 +654,35 @@ end function cert_commands.config(arg) if #arg >= 1 and arg[1] ~= "--help" then - local conf_filename = (CFG_DATADIR or ".") .. "/" .. arg[1] .. ".cnf"; + local conf_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".cnf"; if ask_overwrite(conf_filename) then return nil, conf_filename; end local conf = openssl.config.new(); conf:from_prosody(hosts, config, arg); - for k, v in pairs(conf.distinguished_name) do - local nv; - if k == "commonName" then - v = arg[1] - elseif k == "emailAddress" then - v = "xmpp@" .. arg[1]; - end - nv = show_prompt(("%s (%s):"):format(k, nv or v)); - nv = (not nv or nv == "") and v or nv; - if nv:find"[\192-\252][\128-\191]+" then - conf.req.string_mask = "utf8only" + show_message("Please provide details to include in the certificate config file."); + show_message("Leave the field empty to use the default value or '.' to exclude the field.") + for i, k in ipairs(openssl._DN_order) do + local v = conf.distinguished_name[k]; + if v then + local nv; + if k == "commonName" then + v = arg[1] + elseif k == "emailAddress" then + v = "xmpp@" .. arg[1]; + elseif k == "countryName" then + local tld = arg[1]:match"%.([a-z]+)$"; + if tld and #tld == 2 and tld ~= "uk" then + v = tld:upper(); + end + end + nv = show_prompt(("%s (%s):"):format(k, nv or v)); + nv = (not nv or nv == "") and v or nv; + if nv:find"[\192-\252][\128-\191]+" then + conf.req.string_mask = "utf8only" + end + conf.distinguished_name[k] = nv ~= "." and nv or nil; end - conf.distinguished_name[k] = nv ~= "." and nv or nil; end local conf_file = io.open(conf_filename, "w"); conf_file:write(conf:serialize()); @@ -686,7 +697,7 @@ end function cert_commands.key(arg) if #arg >= 1 and arg[1] ~= "--help" then - local key_filename = (CFG_DATADIR or ".") .. "/" .. arg[1] .. ".key"; + local key_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".key"; if ask_overwrite(key_filename) then return nil, key_filename; end @@ -708,7 +719,7 @@ end function cert_commands.request(arg) if #arg >= 1 and arg[1] ~= "--help" then - local req_filename = (CFG_DATADIR or ".") .. "/" .. arg[1] .. ".req"; + local req_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".req"; if ask_overwrite(req_filename) then return nil, req_filename; end @@ -726,7 +737,7 @@ end function cert_commands.generate(arg) if #arg >= 1 and arg[1] ~= "--help" then - local cert_filename = (CFG_DATADIR or ".") .. "/" .. arg[1] .. ".crt"; + local cert_filename = (CFG_DATADIR or "./certs") .. "/" .. arg[1] .. ".crt"; if ask_overwrite(cert_filename) then return nil, cert_filename; end diff --git a/tests/test_util_sasl_scram.lua b/tests/test_util_sasl_scram.lua index aeae8748..bc89829f 100644 --- a/tests/test_util_sasl_scram.lua +++ b/tests/test_util_sasl_scram.lua @@ -1,6 +1,6 @@ -local hmac_sha1 = require "util.hmac".sha1; +local hmac_sha1 = require "util.hashes".hmac_sha1; local function toHex(s) return s and (s:gsub(".", function (c) return ("%02x"):format(c:byte()); end)); end diff --git a/tools/migration/Makefile b/tools/migration/Makefile index 5998a5f7..ae402bd2 100644 --- a/tools/migration/Makefile +++ b/tools/migration/Makefile @@ -29,7 +29,8 @@ clean: rm -f migrator.cfg.lua.install prosody-migrator.install: prosody-migrator.lua - sed "s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ + sed "1s/\blua\b/$(RUNWITH)/; \ + s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|;" \ < prosody-migrator.lua > prosody-migrator.install diff --git a/util-src/Makefile b/util-src/Makefile index 2c8243f9..90d65e51 100644 --- a/util-src/Makefile +++ b/util-src/Makefile @@ -9,6 +9,7 @@ OPENSSL_LIB?=crypto CC?=gcc CXX?=g++ LD?=gcc +CFLAGS+=-ggdb .PHONY: all install clean .SUFFIXES: .c .o .so diff --git a/util-src/hashes.c b/util-src/hashes.c index 317deaf3..8f7d7140 100644 --- a/util-src/hashes.c +++ b/util-src/hashes.c @@ -14,14 +14,19 @@ */ #include <string.h> +#include <stdlib.h> +#include <inttypes.h> #include "lua.h" #include "lauxlib.h" #include <openssl/sha.h> #include <openssl/md5.h> -const char* hex_tab = "0123456789abcdef"; -void toHex(const char* in, int length, char* out) { +#define HMAC_IPAD 0x36363636 +#define HMAC_OPAD 0x5c5c5c5c + +const char *hex_tab = "0123456789abcdef"; +void toHex(const unsigned char *in, int length, unsigned char *out) { int i; for (i = 0; i < length; i++) { out[i*2] = hex_tab[(in[i] >> 4) & 0xF]; @@ -34,14 +39,13 @@ static int myFunc(lua_State *L) { \ size_t len; \ const char *s = luaL_checklstring(L, 1, &len); \ int hex_out = lua_toboolean(L, 2); \ - char hash[size]; \ - char result[size*2]; \ - func((const unsigned char*)s, len, (unsigned char*)hash); \ + unsigned char hash[size], result[size*2]; \ + func((const unsigned char*)s, len, hash); \ if (hex_out) { \ toHex(hash, size, result); \ - lua_pushlstring(L, result, size*2); \ + lua_pushlstring(L, (char*)result, size*2); \ } else { \ - lua_pushlstring(L, hash, size);\ + lua_pushlstring(L, (char*)hash, size);\ } \ return 1; \ } @@ -53,15 +57,143 @@ MAKE_HASH_FUNCTION(Lsha384, SHA384, SHA384_DIGEST_LENGTH) MAKE_HASH_FUNCTION(Lsha512, SHA512, SHA512_DIGEST_LENGTH) MAKE_HASH_FUNCTION(Lmd5, MD5, MD5_DIGEST_LENGTH) +struct hash_desc { + int (*Init)(void*); + int (*Update)(void*, const void *, size_t); + int (*Final)(unsigned char*, void*); + size_t digestLength; + void *ctx, *ctxo; +}; + +static void hmac(struct hash_desc *desc, const char *key, size_t key_len, + const char *msg, size_t msg_len, unsigned char *result) +{ + union xory { + unsigned char bytes[64]; + uint32_t quadbytes[16]; + }; + + int i; + char hashedKey[64]; /* Maximum used digest length */ + union xory k_ipad, k_opad; + + if (key_len > 64) { + desc->Init(desc->ctx); + desc->Update(desc->ctx, key, key_len); + desc->Final(desc->ctx, hashedKey); + key = (const char*)hashedKey; + key_len = desc->digestLength; + } + + memcpy(k_ipad.bytes, key, key_len); + memset(k_ipad.bytes + key_len, 0, 64 - key_len); + memcpy(k_opad.bytes, k_ipad.bytes, 64); + + for (i = 0; i < 16; i++) { + k_ipad.quadbytes[i] ^= HMAC_IPAD; + k_opad.quadbytes[i] ^= HMAC_OPAD; + } + + desc->Init(desc->ctx); + desc->Update(desc->ctx, k_ipad.bytes, 64); + desc->Init(desc->ctxo); + desc->Update(desc->ctxo, k_opad.bytes, 64); + desc->Update(desc->ctx, msg, msg_len); + desc->Final(result, desc->ctx); + desc->Update(desc->ctxo, result, desc->digestLength); + desc->Final(result, desc->ctxo); +} + +#define MAKE_HMAC_FUNCTION(myFunc, func, size, type) \ +static int myFunc(lua_State *L) { \ + type ctx, ctxo; \ + unsigned char hash[size], result[2*size]; \ + size_t key_len, msg_len; \ + const char *key = luaL_checklstring(L, 1, &key_len); \ + const char *msg = luaL_checklstring(L, 2, &msg_len); \ + const int hex_out = lua_toboolean(L, 3); \ + struct hash_desc desc; \ + desc.Init = (int (*)(void*))func##_Init; \ + desc.Update = (int (*)(void*, const void *, size_t))func##_Update; \ + desc.Final = (int (*)(unsigned char*, void*))func##_Final; \ + desc.digestLength = size; \ + desc.ctx = &ctx; \ + desc.ctxo = &ctxo; \ + hmac(&desc, key, key_len, msg, msg_len, hash); \ + if (hex_out) { \ + toHex(hash, size, result); \ + lua_pushlstring(L, (char*)result, size*2); \ + } else { \ + lua_pushlstring(L, (char*)hash, size); \ + } \ + return 1; \ +} + +MAKE_HMAC_FUNCTION(Lhmac_sha1, SHA1, SHA_DIGEST_LENGTH, SHA_CTX) +MAKE_HMAC_FUNCTION(Lhmac_sha256, SHA256, SHA256_DIGEST_LENGTH, SHA256_CTX) +MAKE_HMAC_FUNCTION(Lhmac_sha512, SHA512, SHA512_DIGEST_LENGTH, SHA512_CTX) +MAKE_HMAC_FUNCTION(Lhmac_md5, MD5, MD5_DIGEST_LENGTH, MD5_CTX) + +static int LscramHi(lua_State *L) { + union xory { + unsigned char bytes[SHA_DIGEST_LENGTH]; + uint32_t quadbytes[SHA_DIGEST_LENGTH/4]; + }; + int i; + SHA_CTX ctx, ctxo; + unsigned char Ust[SHA_DIGEST_LENGTH]; + union xory Und; + union xory res; + size_t str_len, salt_len; + struct hash_desc desc; + const char *str = luaL_checklstring(L, 1, &str_len); + const char *salt = luaL_checklstring(L, 2, &salt_len); + char *salt2; + const int iter = luaL_checkinteger(L, 3); + + desc.Init = (int (*)(void*))SHA1_Init; + desc.Update = (int (*)(void*, const void *, size_t))SHA1_Update; + desc.Final = (int (*)(unsigned char*, void*))SHA1_Final; + desc.digestLength = SHA_DIGEST_LENGTH; + desc.ctx = &ctx; + desc.ctxo = &ctxo; + + salt2 = malloc(salt_len + 4); + if (salt2 == NULL) + luaL_error(L, "Out of memory in scramHi"); + memcpy(salt2, salt, salt_len); + memcpy(salt2 + salt_len, "\0\0\0\1", 4); + hmac(&desc, str, str_len, salt2, salt_len + 4, Ust); + free(salt2); + + memcpy(res.bytes, Ust, sizeof(res)); + for (i = 1; i < iter; i++) { + int j; + hmac(&desc, str, str_len, (char*)Ust, sizeof(Ust), Und.bytes); + for (j = 0; j < SHA_DIGEST_LENGTH/4; j++) + res.quadbytes[j] ^= Und.quadbytes[j]; + memcpy(Ust, Und.bytes, sizeof(Ust)); + } + + lua_pushlstring(L, (char*)res.bytes, SHA_DIGEST_LENGTH); + + return 1; +} + static const luaL_Reg Reg[] = { - { "sha1", Lsha1 }, - { "sha224", Lsha224 }, - { "sha256", Lsha256 }, - { "sha384", Lsha384 }, - { "sha512", Lsha512 }, - { "md5", Lmd5 }, - { NULL, NULL } + { "sha1", Lsha1 }, + { "sha224", Lsha224 }, + { "sha256", Lsha256 }, + { "sha384", Lsha384 }, + { "sha512", Lsha512 }, + { "md5", Lmd5 }, + { "hmac_sha1", Lhmac_sha1 }, + { "hmac_sha256", Lhmac_sha256 }, + { "hmac_sha512", Lhmac_sha512 }, + { "hmac_md5", Lhmac_md5 }, + { "scram_Hi_sha1", LscramHi }, + { NULL, NULL } }; LUALIB_API int luaopen_util_hashes(lua_State *L) diff --git a/util-src/pposix.c b/util-src/pposix.c index 05303d99..f5cc8270 100644 --- a/util-src/pposix.c +++ b/util-src/pposix.c @@ -13,7 +13,7 @@ * POSIX support functions for Lua */ -#define MODULE_VERSION "0.3.5" +#define MODULE_VERSION "0.3.6" #include <stdlib.h> #include <math.h> @@ -204,12 +204,13 @@ int level_constants[] = { }; int lc_syslog_log(lua_State* L) { - int level = luaL_checkoption(L, 1, "notice", level_strings); - level = level_constants[level]; + int level = level_constants[luaL_checkoption(L, 1, "notice", level_strings)]; - luaL_checkstring(L, 2); + if(lua_gettop(L) == 3) + syslog(level, "%s: %s", luaL_checkstring(L, 2), luaL_checkstring(L, 3)); + else + syslog(level, "%s", lua_tostring(L, 2)); - syslog(level, "%s", lua_tostring(L, 2)); return 0; } @@ -484,6 +485,9 @@ int string2resource(const char *s) { if (!strcmp(s, "NPROC")) return RLIMIT_NPROC; if (!strcmp(s, "RSS")) return RLIMIT_RSS; #endif +#ifdef RLIMIT_NICE + if (!strcmp(s, "NICE")) return RLIMIT_NICE; +#endif return -1; } diff --git a/util/adhoc.lua b/util/adhoc.lua new file mode 100644 index 00000000..671e85cf --- /dev/null +++ b/util/adhoc.lua @@ -0,0 +1,31 @@ +local function new_simple_form(form, result_handler) + return function(self, data, state) + if state then + if data.action == "cancel" then + return { status = "canceled" }; + end + local fields, err = form:data(data.form); + return result_handler(fields, err, data); + else + return { status = "executing", actions = {"next", "complete", default = "complete"}, form = form }, "executing"; + end + end +end + +local function new_initial_data_form(form, initial_data, result_handler) + return function(self, data, state) + if state then + if data.action == "cancel" then + return { status = "canceled" }; + end + local fields, err = form:data(data.form); + return result_handler(fields, err, data); + else + return { status = "executing", actions = {"next", "complete", default = "complete"}, + form = { layout = form, values = initial_data() } }, "executing"; + end + end +end + +return { new_simple_form = new_simple_form, + new_initial_data_form = new_initial_data_form }; diff --git a/util/datamanager.lua b/util/datamanager.lua index 383e738f..4a4d62b3 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -187,17 +187,25 @@ function store(username, host, datastore, data) -- save the datastore local d = "return " .. serialize(data) .. ";\n"; - local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d); - if not ok then - log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); - return nil, "Error saving to storage"; - end - if next(data) == nil then -- try to delete empty datastore - log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); - os_remove(getpath(username, host, datastore)); - end - -- we write data even when we are deleting because lua doesn't have a - -- platform independent way of checking for non-exisitng files + local mkdir_cache_cleared; + repeat + local ok, msg = atomic_store(getpath(username, host, datastore, nil, true), d); + if not ok then + if not mkdir_cache_cleared then -- We may need to recreate a removed directory + _mkdir = {}; + mkdir_cache_cleared = true; + else + log("error", "Unable to write to %s storage ('%s') for user: %s@%s", datastore, msg, username or "nil", host or "nil"); + return nil, "Error saving to storage"; + end + end + if next(data) == nil then -- try to delete empty datastore + log("debug", "Removing empty %s datastore for user %s@%s", datastore, username or "nil", host or "nil"); + os_remove(getpath(username, host, datastore)); + end + -- we write data even when we are deleting because lua doesn't have a + -- platform independent way of checking for non-exisitng files + until ok; return true; end @@ -354,4 +362,6 @@ function purge(username, host) return #errs == 0, t_concat(errs, ", "); end +_M.path_decode = decode; +_M.path_encode = encode; return _M; diff --git a/util/helpers.lua b/util/helpers.lua index 6103a319..08b86a7c 100644 --- a/util/helpers.lua +++ b/util/helpers.lua @@ -14,6 +14,14 @@ module("helpers", package.seeall); local log = require "util.logger".init("util.debug"); +function log_host_events(host) + return log_events(prosody.hosts[host].events, host); +end + +function revert_log_host_events(host) + return revert_log_events(prosody.hosts[host].events); +end + function log_events(events, name, logger) local f = events.fire_event; if not f then diff --git a/util/hmac.lua b/util/hmac.lua index 6df6986e..51211c7a 100644 --- a/util/hmac.lua +++ b/util/hmac.lua @@ -6,64 +6,10 @@ -- COPYING file in the source package for more information. -- -local hashes = require "util.hashes" - -local s_char = string.char; -local s_gsub = string.gsub; -local s_rep = string.rep; - -module "hmac" - -local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; -local function xor(x, y) - local lowx, lowy = x % 16, y % 16; - local hix, hiy = (x - lowx) / 16, (y - lowy) / 16; - local lowr, hir = xor_map[lowx * 16 + lowy + 1], xor_map[hix * 16 + hiy + 1]; - local r = hir * 16 + lowr; - return r; -end -local opadc, ipadc = s_char(0x5c), s_char(0x36); -local ipad_map = {}; -local opad_map = {}; -for i=0,255 do - ipad_map[s_char(i)] = s_char(xor(0x36, i)); - opad_map[s_char(i)] = s_char(xor(0x5c, i)); -end - ---[[ -key - the key to use in the hash -message - the message to hash -hash - the hash function -blocksize - the blocksize for the hash function in bytes -hex - return raw hash or hexadecimal string ---]] -function hmac(key, message, hash, blocksize, hex) - if #key > blocksize then - key = hash(key) - end +-- COMPAT: Only for external pre-0.9 modules - local padding = blocksize - #key; - local ipad = s_gsub(key, ".", ipad_map)..s_rep(ipadc, padding); - local opad = s_gsub(key, ".", opad_map)..s_rep(opadc, padding); - - return hash(opad..hash(ipad..message), hex) -end - -function md5(key, message, hex) - return hmac(key, message, hashes.md5, 64, hex) -end - -function sha1(key, message, hex) - return hmac(key, message, hashes.sha1, 64, hex) -end - -function sha256(key, message, hex) - return hmac(key, message, hashes.sha256, 64, hex) -end +local hashes = require "util.hashes" -return _M +return { md5 = hashes.hmac_md5, + sha1 = hashes.hmac_sha1, + sha256 = hashes.hmac_sha256 }; diff --git a/util/http.lua b/util/http.lua index 5b49d1d0..f7259920 100644 --- a/util/http.lua +++ b/util/http.lua @@ -5,11 +5,60 @@ -- COPYING file in the source package for more information. -- -local http = {}; +local format, char = string.format, string.char; +local pairs, ipairs, tonumber = pairs, ipairs, tonumber; +local t_insert, t_concat = table.insert, table.concat; -function http.contains_token(field, token) +local function urlencode(s) + return s and (s:gsub("[^a-zA-Z0-9.~_-]", function (c) return format("%%%02x", c:byte()); end)); +end +local function urldecode(s) + return s and (s:gsub("%%(%x%x)", function (c) return char(tonumber(c,16)); end)); +end + +local function _formencodepart(s) + return s and (s:gsub("%W", function (c) + if c ~= " " then + return format("%%%02x", c:byte()); + else + return "+"; + end + end)); +end + +local function formencode(form) + local result = {}; + if form[1] then -- Array of ordered { name, value } + for _, field in ipairs(form) do + t_insert(result, _formencodepart(field.name).."=".._formencodepart(field.value)); + end + else -- Unordered map of name -> value + for name, value in pairs(form) do + t_insert(result, _formencodepart(name).."=".._formencodepart(value)); + end + end + return t_concat(result, "&"); +end + +local function formdecode(s) + if not s:match("=") then return urldecode(s); end + local r = {}; + for k, v in s:gmatch("([^=&]*)=([^&]*)") do + k, v = k:gsub("%+", "%%20"), v:gsub("%+", "%%20"); + k, v = urldecode(k), urldecode(v); + t_insert(r, { name = k, value = v }); + r[k] = v; + end + return r; +end + +local function contains_token(field, token) field = ","..field:gsub("[ \t]", ""):lower()..","; return field:find(","..token:lower()..",", 1, true) ~= nil; end -return http; +return { + urlencode = urlencode, urldecode = urldecode; + formencode = formencode, formdecode = formdecode; + contains_token = contains_token; +}; diff --git a/util/httpstream.lua b/util/httpstream.lua deleted file mode 100644 index 190b3ed6..00000000 --- a/util/httpstream.lua +++ /dev/null @@ -1,134 +0,0 @@ - -local coroutine = coroutine; -local tonumber = tonumber; - -local deadroutine = coroutine.create(function() end); -coroutine.resume(deadroutine); - -module("httpstream") - -local function parser(success_cb, parser_type, options_cb) - local data = coroutine.yield(); - local function readline() - local pos = data:find("\r\n", nil, true); - while not pos do - data = data..coroutine.yield(); - pos = data:find("\r\n", nil, true); - end - local r = data:sub(1, pos-1); - data = data:sub(pos+2); - return r; - end - local function readlength(n) - while #data < n do - data = data..coroutine.yield(); - end - local r = data:sub(1, n); - data = data:sub(n + 1); - return r; - end - local function readheaders() - local headers = {}; -- read headers - while true do - local line = readline(); - if line == "" then break; end -- headers done - local key, val = line:match("^([^%s:]+): *(.*)$"); - if not key then coroutine.yield("invalid-header-line"); end -- TODO handle multi-line and invalid headers - key = key:lower(); - headers[key] = headers[key] and headers[key]..","..val or val; - end - return headers; - end - - if not parser_type or parser_type == "server" then - while true do - -- read status line - local status_line = readline(); - local method, path, httpversion = status_line:match("^(%S+)%s+(%S+)%s+HTTP/(%S+)$"); - if not method then coroutine.yield("invalid-status-line"); end - path = path:gsub("^//+", "/"); -- TODO parse url more - local headers = readheaders(); - - -- read body - local len = tonumber(headers["content-length"]); - len = len or 0; -- TODO check for invalid len - local body = readlength(len); - - success_cb({ - method = method; - path = path; - httpversion = httpversion; - headers = headers; - body = body; - }); - end - elseif parser_type == "client" then - while true do - -- read status line - local status_line = readline(); - local httpversion, status_code, reason_phrase = status_line:match("^HTTP/(%S+)%s+(%d%d%d)%s+(.*)$"); - status_code = tonumber(status_code); - if not status_code then coroutine.yield("invalid-status-line"); end - local headers = readheaders(); - - -- read body - local have_body = not - ( (options_cb and options_cb().method == "HEAD") - or (status_code == 204 or status_code == 304 or status_code == 301) - or (status_code >= 100 and status_code < 200) ); - - local body; - if have_body then - local len = tonumber(headers["content-length"]); - if headers["transfer-encoding"] == "chunked" then - body = ""; - while true do - local chunk_size = readline():match("^%x+"); - if not chunk_size then coroutine.yield("invalid-chunk-size"); end - chunk_size = tonumber(chunk_size, 16) - if chunk_size == 0 then break; end - body = body..readlength(chunk_size); - if readline() ~= "" then coroutine.yield("invalid-chunk-ending"); end - end - local trailers = readheaders(); - elseif len then -- TODO check for invalid len - body = readlength(len); - else -- read to end - repeat - local newdata = coroutine.yield(); - data = data..newdata; - until newdata == ""; - body, data = data, ""; - end - end - - success_cb({ - code = status_code; - httpversion = httpversion; - headers = headers; - body = body; - }); - end - else coroutine.yield("unknown-parser-type"); end -end - -function new(success_cb, error_cb, parser_type, options_cb) - local co = coroutine.create(parser); - coroutine.resume(co, success_cb, parser_type, options_cb) - return { - feed = function(self, data) - if not data then - if parser_type == "client" then coroutine.resume(co, ""); end - co = deadroutine; - return error_cb(); - end - local success, result = coroutine.resume(co, data); - if result then - co = deadroutine; - return error_cb(result); - end - end; - }; -end - -return _M; diff --git a/util/ip.lua b/util/ip.lua index 2f09c034..de287b16 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -64,9 +64,6 @@ local function v4scope(ip) -- Link-local unicast: elseif fields[1] == 169 and fields[2] == 254 then return 0x2; - -- Site-local unicast: - elseif (fields[1] == 10) or (fields[1] == 192 and fields[2] == 168) or (fields[1] == 172 and (fields[2] >= 16 and fields[2] < 32)) then - return 0x5; -- Global unicast: else return 0xE; @@ -97,6 +94,14 @@ local function label(ip) return 0; elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then return 2; + elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + return 5; + elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + return 13; + elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + return 11; + elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + return 12; elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then return 3; elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then @@ -111,10 +116,18 @@ local function precedence(ip) return 50; elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then return 30; + elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + return 5; + elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + return 3; + elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + return 1; + elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + return 1; elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then - return 20; + return 1; elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then - return 10; + return 35; else return 40; end diff --git a/util/iterators.lua b/util/iterators.lua index fb89f4a5..1f6aacb8 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -122,6 +122,11 @@ function it.tail(n, f, s, var) --return reverse(head(n, reverse(f, s, var))); end +local function _ripairs_iter(t, key) if key > 1 then return key-1, t[key-1]; end end +function it.ripairs(t) + return _ripairs_iter, t, #t+1; +end + local function _range_iter(max, curr) if curr < max then return curr + 1; end end function it.range(x, y) if not y then x, y = 1, x; end -- Default to 1..x if y not given diff --git a/util/json.lua b/util/json.lua index efc602f0..9c2dd2c6 100644 --- a/util/json.lua +++ b/util/json.lua @@ -1,3 +1,10 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- local type = type; local t_insert, t_concat, t_remove, t_sort = table.insert, table.concat, table.remove, table.sort; @@ -9,6 +16,9 @@ local error = error; local newproxy, getmetatable = newproxy, getmetatable; local print = print; +local has_array, array = pcall(require, "util.array"); +local array_mt = hasarray and getmetatable(array()) or {}; + --module("json") local json = {}; @@ -29,6 +39,19 @@ for i=0,31 do if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end end +local function codepoint_to_utf8(code) + if code < 0x80 then return s_char(code); end + local bits0_6 = code % 64; + if code < 0x800 then + local bits6_5 = (code - bits0_6) / 64; + return s_char(0x80 + 0x40 + bits6_5, 0x80 + bits0_6); + end + local bits0_12 = code % 4096; + local bits6_6 = (bits0_12 - bits0_6) / 64; + local bits12_4 = (code - bits0_12) / 4096; + return s_char(0x80 + 0x40 + 0x20 + bits12_4, 0x80 + bits6_6, 0x80 + bits0_6); +end + local valid_types = { number = true, string = true, @@ -130,7 +153,12 @@ function simplesave(o, buffer) elseif t == "string" then stringsave(o, buffer); elseif t == "table" then - tablesave(o, buffer); + local mt = getmetatable(o); + if mt == array_mt then + arraysave(o, buffer); + else + tablesave(o, buffer); + end elseif t == "boolean" then t_insert(buffer, (o and "true" or "false")); else @@ -148,6 +176,11 @@ function json.encode_ordered(obj) simplesave(obj, t); return t_concat(t); end +function json.encode_array(obj) + local t = {}; + arraysave(obj, t); + return t_concat(t); +end ----------------------------------- @@ -197,7 +230,7 @@ function json.decode(json) local readvalue; local function readarray() - local t = {}; + local t = setmetatable({}, array_mt); next(); -- skip '[' skipstuff(); if ch == "]" then next(); return t; end @@ -244,7 +277,7 @@ function json.decode(json) if not ch:match("[0-9a-fA-F]") then error("invalid unicode escape sequence in string"); end seq = seq..ch; end - s = s..s.char(tonumber(seq, 16)); -- FIXME do proper utf-8 + s = s..codepoint_to_utf8(tonumber(seq, 16)); next(); else error("invalid escape sequence in string"); end end diff --git a/util/openssl.lua b/util/openssl.lua index b3dc2943..ef3fba96 100644 --- a/util/openssl.lua +++ b/util/openssl.lua @@ -23,11 +23,12 @@ function config.new() prompt = "no", }, distinguished_name = { - commonName = "example.com", countryName = "GB", + -- stateOrProvinceName = "", localityName = "The Internet", organizationName = "Your Organisation", organizationalUnitName = "XMPP Department", + commonName = "example.com", emailAddress = "xmpp@example.com", }, v3_extensions = { @@ -43,6 +44,17 @@ function config.new() }, ssl_config_mt); end +local DN_order = { + "countryName"; + "stateOrProvinceName"; + "localityName"; + "streetAddress"; + "organizationName"; + "organizationalUnitName"; + "commonName"; + "emailAddress"; +} +_M._DN_order = DN_order; function ssl_config:serialize() local s = ""; for k, t in pairs(self) do @@ -53,6 +65,14 @@ function ssl_config:serialize() s = s .. s_format("%s.%d = %s\n", san, i -1, n[i]); end end + elseif k == "distinguished_name" then + for i=1,#DN_order do + local k = DN_order[i] + local v = t[k]; + if v then + s = s .. ("%s = %s\n"):format(k, v); + end + end else for k, v in pairs(t) do s = s .. ("%s = %s\n"):format(k, v); @@ -100,13 +120,13 @@ function ssl_config:from_prosody(hosts, config, certhosts) if name == certhost or name:sub(-1-#certhost) == "."..certhost then found_matching_hosts = true; self:add_dNSName(name); - --print(name .. "#component_module: " .. (config.get(name, "core", "component_module") or "nil")); - if config.get(name, "core", "component_module") == nil then + --print(name .. "#component_module: " .. (config.get(name, "component_module") or "nil")); + if config.get(name, "component_module") == nil then self:add_sRVName(name, "xmpp-client"); end - --print(name .. "#anonymous_login: " .. tostring(config.get(name, "core", "anonymous_login"))); - if not (config.get(name, "core", "anonymous_login") or - config.get(name, "core", "authentication") == "anonymous") then + --print(name .. "#anonymous_login: " .. tostring(config.get(name, "anonymous_login"))); + if not (config.get(name, "anonymous_login") or + config.get(name, "authentication") == "anonymous") then self:add_sRVName(name, "xmpp-server"); end self:add_xmppAddr(name); diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index e38f85d4..b80a69f2 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -140,11 +140,12 @@ function adduser(params) if not host_session then return false, "no-such-host"; end + + storagemanager.initialize_host(host); local provider = host_session.users; if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - storagemanager.initialize_host(host); local ok, errmsg = usermanager.create_user(user, password, host); if not ok then @@ -155,11 +156,12 @@ end function user_exists(params) local user, host, password = nodeprep(params.user), nameprep(params.host), params.password; + + storagemanager.initialize_host(host); local provider = prosody.hosts[host].users; if not(provider) or provider.name == "null" then usermanager.initialize_host(host); end - storagemanager.initialize_host(host); return usermanager.user_exists(user, host); end @@ -182,12 +184,12 @@ function deluser(params) end function getpid() - local pidfile = config.get("*", "core", "pidfile"); + local pidfile = config.get("*", "pidfile"); if not pidfile then return false, "no-pidfile"; end - local modules_enabled = set.new(config.get("*", "core", "modules_enabled")); + local modules_enabled = set.new(config.get("*", "modules_enabled")); if not modules_enabled:contains("posix") then return false, "no-posix"; end diff --git a/util/rfc3484.lua b/util/rfc6724.lua index 5ee572a0..c8aec631 100644 --- a/util/rfc3484.lua +++ b/util/rfc6724.lua @@ -1,13 +1,22 @@ -- Prosody IM --- Copyright (C) 2008-2011 Florian Zeitz +-- Copyright (C) 2011-2013 Florian Zeitz -- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- -local commonPrefixLength = require"util.ip".commonPrefixLength +-- This is used to sort destination addresses by preference +-- during S2S connections. +-- We can't hand this off to getaddrinfo, since it blocks + +local ip_commonPrefixLength = require"util.ip".commonPrefixLength local new_ip = require"util.ip".new_ip; +local function commonPrefixLength(ipA, ipB) + local len = ip_commonPrefixLength(ipA, ipB); + return len < 64 and len or 64; +end + local function t_sort(t, comp) for i = 1, (#t - 1) do for j = (i + 1), #t do @@ -56,7 +65,7 @@ local function source(dest, candidates) return false; end - -- Rule 7: Prefer public addresses (over temporary ones) + -- Rule 7: Prefer temporary addresses (over public ones) -- XXX: No way to determine this -- Rule 8: Use longest matching prefix if commonPrefixLength(ipA, dest) > commonPrefixLength(ipB, dest) then diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index d0e8987c..cf2f0ede 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -15,8 +15,9 @@ local s_match = string.match; local type = type local string = string local base64 = require "util.encodings".base64; -local hmac_sha1 = require "util.hmac".sha1; +local hmac_sha1 = require "util.hashes".hmac_sha1; local sha1 = require "util.hashes".sha1; +local Hi = require "util.hashes".scram_Hi_sha1; local generate_uuid = require "util.uuid".generate; local saslprep = require "util.encodings".stringprep.saslprep; local nodeprep = require "util.encodings".stringprep.nodeprep; @@ -65,18 +66,6 @@ local function binaryXOR( a, b ) return t_concat(result); end --- hash algorithm independent Hi(PBKDF2) implementation -function Hi(hmac, str, salt, i) - local Ust = hmac(str, salt.."\0\0\0\1"); - local res = Ust; - for n=1,i-1 do - local Und = hmac(str, Ust) - res = binaryXOR(res, Und) - Ust = Und - end - return res -end - local function validate_username(username, _nodeprep) -- check for forbidden char sequences for eq in username:gmatch("=(.?.?)") do @@ -110,7 +99,7 @@ function getAuthenticationDatabaseSHA1(password, salt, iteration_count) if iteration_count < 4096 then log("warn", "Iteration count < 4096 which is the suggested minimum according to RFC 5802.") end - local salted_password = Hi(hmac_sha1, password, salt, iteration_count); + local salted_password = Hi(password, salt, iteration_count); local stored_key = sha1(hmac_sha1(salted_password, "Client Key")) local server_key = hmac_sha1(salted_password, "Server Key"); return true, stored_key, server_key diff --git a/util/sql.lua b/util/sql.lua new file mode 100644 index 00000000..f360d6d0 --- /dev/null +++ b/util/sql.lua @@ -0,0 +1,340 @@ + +local setmetatable, getmetatable = setmetatable, getmetatable; +local ipairs, unpack, select = ipairs, unpack, select; +local tonumber, tostring = tonumber, tostring; +local assert, xpcall, debug_traceback = assert, xpcall, debug.traceback; +local t_concat = table.concat; +local s_char = string.char; +local log = require "util.logger".init("sql"); + +local DBI = require "DBI"; +-- This loads all available drivers while globals are unlocked +-- LuaDBI should be fixed to not set globals. +DBI.Drivers(); +local build_url = require "socket.url".build; + +module("sql") + +local column_mt = {}; +local table_mt = {}; +local query_mt = {}; +--local op_mt = {}; +local index_mt = {}; + +function is_column(x) return getmetatable(x)==column_mt; end +function is_index(x) return getmetatable(x)==index_mt; end +function is_table(x) return getmetatable(x)==table_mt; end +function is_query(x) return getmetatable(x)==query_mt; end +--function is_op(x) return getmetatable(x)==op_mt; end +--function expr(...) return setmetatable({...}, op_mt); end +function Integer(n) return "Integer()" end +function String(n) return "String()" end + +--[[local ops = { + __add = function(a, b) return "("..a.."+"..b..")" end; + __sub = function(a, b) return "("..a.."-"..b..")" end; + __mul = function(a, b) return "("..a.."*"..b..")" end; + __div = function(a, b) return "("..a.."/"..b..")" end; + __mod = function(a, b) return "("..a.."%"..b..")" end; + __pow = function(a, b) return "POW("..a..","..b..")" end; + __unm = function(a) return "NOT("..a..")" end; + __len = function(a) return "COUNT("..a..")" end; + __eq = function(a, b) return "("..a.."=="..b..")" end; + __lt = function(a, b) return "("..a.."<"..b..")" end; + __le = function(a, b) return "("..a.."<="..b..")" end; +}; + +local functions = { + +}; + +local cmap = { + [Integer] = Integer(); + [String] = String(); +};]] + +function Column(definition) + return setmetatable(definition, column_mt); +end +function Table(definition) + local c = {} + for i,col in ipairs(definition) do + if is_column(col) then + c[i], c[col.name] = col, col; + elseif is_index(col) then + col.table = definition.name; + end + end + return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt); +end +function Index(definition) + return setmetatable(definition, index_mt); +end + +function table_mt:__tostring() + local s = { 'name="'..self.__table__.name..'"' } + for i,col in ipairs(self.__table__) do + s[#s+1] = tostring(col); + end + return 'Table{ '..t_concat(s, ", ")..' }' +end +table_mt.__index = {}; +function table_mt.__index:create(engine) + return engine:_create_table(self); +end +function table_mt:__call(...) + -- TODO +end +function column_mt:__tostring() + return 'Column{ name="'..self.name..'", type="'..self.type..'" }' +end +function index_mt:__tostring() + local s = 'Index{ name="'..self.name..'"'; + for i=1,#self do s = s..', "'..self[i]:gsub("[\\\"]", "\\%1")..'"'; end + return s..' }'; +-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' +end +-- + +local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end +local function parse_url(url) + local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); + assert(scheme, "Invalid URL format"); + local username, password, host, port; + local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); + if not authpart then hostpart = secondpart; end + if authpart then + username, password = authpart:match("([^:]*):(.*)"); + username = username or authpart; + password = password and urldecode(password); + end + if hostpart then + host, port = hostpart:match("([^:]*):(.*)"); + host = host or hostpart; + port = port and assert(tonumber(port), "Invalid URL format"); + end + return { + scheme = scheme:lower(); + username = username; password = password; + host = host; port = port; + database = #database > 0 and database or nil; + }; +end + +--[[local session = {}; + +function session.query(...) + local rets = {...}; + local query = setmetatable({ __rets = rets, __filters }, query_mt); + return query; +end +-- + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end]] + +local engine = {}; +function engine:connect() + if self.conn then return true; end + + local params = self.params; + assert(params.driver, "no driver") + local dbh, err = DBI.Connect( + params.driver, params.database, + params.username, params.password, + params.host, params.port + ); + if not dbh then return nil, err; end + dbh:autocommit(false); -- don't commit automatically + self.conn = dbh; + self.prepared = {}; + return true; +end +function engine:execute(sql, ...) + local success, err = self:connect(); + if not success then return success, err; end + local prepared = self.prepared; + + local stmt = prepared[sql]; + if not stmt then + local err; + stmt, err = self.conn:prepare(sql); + if not stmt then return stmt, err; end + prepared[sql] = stmt; + end + + local success, err = stmt:execute(...); + if not success then return success, err; end + return stmt; +end + +local result_mt = { __index = { + affected = function(self) return self.__affected; end; + rowcount = function(self) return self.__rowcount; end; +} }; + +function engine:execute_query(sql, ...) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local stmt = assert(self.conn:prepare(sql)); + assert(stmt:execute(...)); + return stmt:rows(); +end +function engine:execute_update(sql, ...) + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local prepared = self.prepared; + local stmt = prepared[sql]; + if not stmt then + stmt = assert(self.conn:prepare(sql)); + prepared[sql] = stmt; + end + assert(stmt:execute(...)); + return setmetatable({ __affected = stmt:affected(), __rowcount = stmt:rowcount() }, result_mt); +end +engine.insert = engine.execute_update; +engine.select = engine.execute_query; +engine.delete = engine.execute_update; +engine.update = engine.execute_update; +function engine:_transaction(func, ...) + if not self.conn then + local a,b = self:connect(); + if not a then return a,b; end + end + --assert(not self.__transaction, "Recursive transactions not allowed"); + local args, n_args = {...}, select("#", ...); + local function f() return func(unpack(args, 1, n_args)); end + self.__transaction = true; + local success, a, b, c = xpcall(f, debug_traceback); + self.__transaction = nil; + if success then + log("debug", "SQL transaction success [%s]", tostring(func)); + local ok, err = self.conn:commit(); + if not ok then return ok, err; end -- commit failed + return success, a, b, c; + else + log("debug", "SQL transaction failure [%s]: %s", tostring(func), a); + if self.conn then self.conn:rollback(); end + return success, a; + end +end +function engine:transaction(...) + local a,b = self:_transaction(...); + if not a then + local conn = self.conn; + if not conn or not conn:ping() then + self.conn = nil; + a,b = self:_transaction(...); + end + end + return a,b; +end +function engine:_create_index(index) + local sql = "CREATE INDEX `"..index.name.."` ON `"..index.table.."` ("; + for i=1,#index do + sql = sql.."`"..index[i].."`"; + if i ~= #index then sql = sql..", "; end + end + sql = sql..");" + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + elseif self.params.driver == "MySQL" then + sql = sql:gsub("`([,)])", "`(20)%1"); + end + --print(sql); + return self:execute(sql); +end +function engine:_create_table(table) + local sql = "CREATE TABLE `"..table.name.."` ("; + for i,col in ipairs(table.c) do + sql = sql.."`"..col.name.."` "..col.type; + if col.nullable == false then sql = sql.." NOT NULL"; end + if i ~= #table.c then sql = sql..", "; end + end + sql = sql.. ");" + if self.params.driver == "PostgreSQL" then + sql = sql:gsub("`", "\""); + end + local success,err = self:execute(sql); + if not success then return success,err; end + for i,v in ipairs(table.__table__) do + if is_index(v) then + self:_create_index(v); + end + end + return success; +end +local engine_mt = { __index = engine }; + +local function db2uri(params) + return build_url{ + scheme = params.driver, + user = params.username, + password = params.password, + host = params.host, + port = params.port, + path = params.database, + }; +end +local engine_cache = {}; -- TODO make weak valued +function create_engine(self, params) + local url = db2uri(params); + if not engine_cache[url] then + local engine = setmetatable({ url = url, params = params }, engine_mt); + engine_cache[url] = engine; + end + return engine_cache[url]; +end + + +--[[Users = Table { + name="users"; + Column { name="user_id", type=String(), primary_key=true }; +}; +print(Users) +print(Users.c.user_id)]] + +--local engine = create_engine('postgresql://scott:tiger@localhost:5432/mydatabase'); +--[[local engine = create_engine{ driver = "SQLite3", database = "./alchemy.sqlite" }; + +local i = 0; +for row in assert(engine:execute("select * from sqlite_master")):rows(true) do + i = i+1; + print(i); + for k,v in pairs(row) do + print("",k,v); + end +end +print("---") + +Prosody = Table { + name="prosody"; + Column { name="host", type="TEXT", nullable=false }; + Column { name="user", type="TEXT", nullable=false }; + Column { name="store", type="TEXT", nullable=false }; + Column { name="key", type="TEXT", nullable=false }; + Column { name="type", type="TEXT", nullable=false }; + Column { name="value", type="TEXT", nullable=false }; + Index { name="prosody_index", "host", "user", "store", "key" }; +}; +--print(Prosody); +assert(engine:transaction(function() + assert(Prosody:create(engine)); +end)); + +for row in assert(engine:execute("select user from prosody")):rows(true) do + print("username:", row['username']) +end +--result.close();]] + +return _M; diff --git a/util/stanza.lua b/util/stanza.lua index a0ab2a5a..7c214210 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -18,6 +18,7 @@ local pairs = pairs; local ipairs = ipairs; local type = type; local s_gsub = string.gsub; +local s_sub = string.sub; local s_find = string.find; local os = os; @@ -153,7 +154,7 @@ function stanza_mt:maptags(callback) local n_children, n_tags = #self, #tags; local i = 1; - while curr_tag <= n_tags do + while curr_tag <= n_tags and n_tags > 0 do if self[i] == tags[curr_tag] then local ret = callback(self[i]); if ret == nil then @@ -161,17 +162,44 @@ function stanza_mt:maptags(callback) t_remove(tags, curr_tag); n_children = n_children - 1; n_tags = n_tags - 1; + i = i - 1; + curr_tag = curr_tag - 1; else self[i] = ret; - tags[i] = ret; + tags[curr_tag] = ret; end - i = i + 1; curr_tag = curr_tag + 1; end + i = i + 1; end return self; end +function stanza_mt:find(path) + local pos = 1; + local len = #path + 1; + + repeat + local xmlns, name, text; + local char = s_sub(path, pos, pos); + if char == "@" then + return self.attr[s_sub(path, pos + 1)]; + elseif char == "{" then + xmlns, pos = s_match(path, "^([^}]+)}()", pos + 1); + end + name, text, pos = s_match(path, "^([^@/#]*)([/#]?)()", pos); + name = name ~= "" and name or nil; + if pos == len then + if text == "#" then + return self:get_child_text(name, xmlns); + end + return self:get_child(name, xmlns); + end + self = self:get_child(name, xmlns); + until not self +end + + local xml_escape do local escape_table = { ["'"] = "'", ["\""] = """, ["<"] = "<", [">"] = ">", ["&"] = "&" }; |