diff options
250 files changed, 12768 insertions, 5307 deletions
@@ -15,7 +15,6 @@ config.unix *.rej *.save *~ -*.report *.o *.so *.install @@ -27,3 +26,6 @@ config.unix *.exp *.lib *.obj +luacov.report.out +luacov.report.out.index +luacov.stats.out
\ No newline at end of file diff --git a/.luacheckrc b/.luacheckrc index 91face3f..0764ce92 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -1,19 +1,34 @@ cache = true -read_globals = { "prosody", "hosts", "import" } -globals = { "_M" } -allow_defined_top = true -module = true unused_secondaries = false codes = true -ignore = { "411/err", "421/err", "411/ok", "421/ok", "211/_ENV", "431/log" } +ignore = { "411/err", "421/err", "411/ok", "421/ok", "211/_ENV", "431/log", "143/table", "113/unpack" } max_line_length = 150 +read_globals = { + "prosody", + "import", +}; +files["prosody"] = { + allow_defined_top = true; + module = true; +} +files["prosodyctl"] = { + allow_defined_top = true; + module = true; +}; files["core/"] = { - read_globals = { "prosody", "hosts" }; - globals = { "prosody.hosts.?", "hosts.?" }; + globals = { + "prosody.hosts.?", + }; +} +files["util/"] = { + -- Ignore unwrapped license text + max_comment_line_length = false; } files["plugins/"] = { + module = true; + allow_defined_top = true; read_globals = { -- Module instance "module.name", @@ -51,8 +66,6 @@ files["plugins/"] = { "module.get_option_set", "module.get_option_string", "module.handle_items", - "module.has_feature", - "module.has_identity", "module.hook", "module.hook_global", "module.hook_object_event", @@ -74,10 +87,11 @@ files["plugins/"] = { "module.wrap_event", "module.wrap_global", "module.wrap_object_event", + + -- mod_http API + "module.http_url", }; globals = { - "_M", - -- Methods that can be set on module API "module.unload", "module.add_host", @@ -89,16 +103,87 @@ files["plugins/"] = { "module.environment", }; } -files["tests/"] = { - read_globals = { - "testlib_new_env", - "assert_equal", - "assert_table", - "assert_function", - "assert_string", - "assert_boolean", - "assert_is", - "assert_is_not", - "runtest", +files["spec/"] = { + std = "+busted" +} +files["prosody.cfg.lua"] = { + ignore = { "131" }; + globals = { + "Host", + "host", + "VirtualHost", + "Component", + "component", + "Include", + "include", + "RunScript" }; } + +if os.getenv("PROSODY_STRICT_LINT") ~= "1" then + -- These files have not yet been brought up to standard + -- Do not add more files here, but do help us fix these! + local exclude_files = { + "doc/net.server.lua"; + + "fallbacks/bit.lua"; + "fallbacks/lxp.lua"; + + "net/adns.lua"; + "net/cqueues.lua"; + "net/dns.lua"; + "net/server_select.lua"; + + "plugins/mod_admin_adhoc.lua"; + "plugins/mod_admin_telnet.lua"; + "plugins/mod_announce.lua"; + "plugins/mod_bosh.lua"; + "plugins/mod_groups.lua"; + "plugins/mod_http_files.lua"; + "plugins/mod_http.lua"; + "plugins/mod_legacyauth.lua"; + "plugins/mod_net_multiplex.lua"; + "plugins/mod_pep.lua"; + "plugins/mod_pep_plus.lua"; + "plugins/mod_privacy.lua"; + "plugins/mod_s2s/mod_s2s.lua"; + "plugins/mod_s2s/s2sout.lib.lua"; + "plugins/mod_storage_sql1.lua"; + "plugins/mod_storage_sql.lua"; + "plugins/mod_websocket.lua"; + + "spec/core_configmanager_spec.lua"; + "spec/core_moduleapi_spec.lua"; + "spec/net_http_parser_spec.lua"; + "spec/util_cache_spec.lua"; + "spec/util_events_spec.lua"; + "spec/util_http_spec.lua"; + "spec/util_ip_spec.lua"; + "spec/util_json_spec.lua"; + "spec/util_multitable_spec.lua"; + "spec/util_rfc6724_spec.lua"; + "spec/util_throttle_spec.lua"; + "spec/util_xmppstream_spec.lua"; + + "tools/ejabberd2prosody.lua"; + "tools/ejabberdsql2prosody.lua"; + "tools/erlparse.lua"; + "tools/jabberd14sql2prosody.lua"; + "tools/migration/migrator.cfg.lua"; + "tools/migration/migrator/jabberd14.lua"; + "tools/migration/migrator/mtools.lua"; + "tools/migration/migrator/prosody_files.lua"; + "tools/migration/migrator/prosody_sql.lua"; + "tools/migration/prosody-migrator.lua"; + "tools/openfire2prosody.lua"; + "tools/xep227toprosody.lua"; + + "util/sasl/digest-md5.lua"; + } + for _, file in ipairs(exclude_files) do + files[file] = { only = {} } + end +else + max_cyclomatic_complexity = 50 + max_line_length = 120 +end @@ -1,3 +1,21 @@ +trunk +===== + +**YYYY-MM-DD** + +New features +------------ + +- Rewritten more extensible MUC module + - Store inactive rooms to disk + - Store rooms to disk on shutdown + - Voice requests +- mod\_pep\_plus +- Persistence of PubSub data (applies to mod\_pubsub and mod\_pep\_plus) +- Asynchronous operations +- Busted for tests +- mod\_muc\_mam (XEP-0313 in groupchats) + 0.10.0 ====== @@ -1,6 +1,6 @@ For full information on our dependencies, version requirements, and -where to find them, see http://prosody.im/doc/depends +where to find them, see https://prosody.im/doc/depends If you have luarocks available on your platform, install the following: @@ -19,7 +19,7 @@ INSTALL_EXEC=$(INSTALL) -m755 MKDIR=install -d MKDIR_PRIVATE=$(MKDIR) -m750 -.PHONY: all test clean install +.PHONY: all test coverage clean install all: prosody.install prosodyctl.install prosody.cfg.lua.install prosody.version $(MAKE) -C util-src install @@ -37,8 +37,9 @@ install: prosody.install prosodyctl.install prosody.cfg.lua.install util/encodin $(INSTALL_EXEC) ./prosodyctl.install $(BIN)/prosodyctl $(INSTALL_DATA) core/*.lua $(SOURCE)/core $(INSTALL_DATA) net/*.lua $(SOURCE)/net - $(MKDIR) $(SOURCE)/net/http $(SOURCE)/net/websocket + $(MKDIR) $(SOURCE)/net/http $(SOURCE)/net/resolvers $(SOURCE)/net/websocket $(INSTALL_DATA) net/http/*.lua $(SOURCE)/net/http + $(INSTALL_DATA) net/resolvers/*.lua $(SOURCE)/net/resolvers $(INSTALL_DATA) net/websocket/*.lua $(SOURCE)/net/websocket $(INSTALL_DATA) util/*.lua $(SOURCE)/util $(INSTALL_DATA) util/*.so $(SOURCE)/util @@ -65,8 +66,19 @@ clean: $(MAKE) clean -C util-src test: - cd tests && $(RUNWITH) test.lua 0 - # Skipping: cd tests && RUNWITH=$(RUNWITH) ./test_util_json.sh + busted --lua=$(RUNWITH) + +coverage: + -rm -- luacov.* + busted --lua=$(RUNWITH) -c + luacov + luacov-console + luacov-console -s + @echo "To inspect individual files run: luacov-console -l FILENAME" + +lint: + luacheck -q $$(hg files -I '**.lua') prosody prosodyctl + @echo $$(sed -n '/^exclude_files/,/^}/p;' .luacheckrc | sed '1d;$d' | wc -l) files ignored util/%.so: $(MAKE) install -C util-src @@ -2,11 +2,11 @@ Welcome hackers! This project accepts and *encourages* contributions. If you would like to get involved you can join us on our mailing list and discussion rooms. More -information on these at http://prosody.im/discuss +information on these at https://prosody.im/discuss Patches are welcome, though before sending we would appreciate if you read docs/coding_style.txt for guidelines on how to format your code, and other tips. -Documentation for developers can be found at http://prosody.im/doc/developers +Documentation for developers can be found at https://prosody.im/doc/developers Have fun :) @@ -1,5 +1,5 @@ (This file was created from -http://prosody.im/doc/installing_from_source on 2013-03-31) +https://prosody.im/doc/installing_from_source on 2013-03-31) ====== Installing from source ====== ==== Dependencies ==== @@ -9,29 +9,29 @@ rapidly prototype new protocols. ## Useful links -Homepage: http://prosody.im/ -Download: http://prosody.im/download -Documentation: http://prosody.im/doc/ +Homepage: https://prosody.im/ +Download: https://prosody.im/download +Documentation: https://prosody.im/doc/ Jabber/XMPP Chat: Address: prosody@conference.prosody.im Web interface: - http://prosody.im/webchat + https://prosody.im/webchat Mailing lists: User support and discussion: - http://groups.google.com/group/prosody-users + https://groups.google.com/group/prosody-users Development discussion: - http://groups.google.com/group/prosody-dev + https://groups.google.com/group/prosody-dev Issue tracker changes: - http://groups.google.com/group/prosody-issues + https://groups.google.com/group/prosody-issues ## Installation See the accompanying INSTALL file for help on building Prosody from source. Alternatively -see our guide at http://prosody.im/doc/install +see our guide at https://prosody.im/doc/install diff --git a/certs/Makefile b/certs/GNUmakefile index fd4a2932..fd4a2932 100644 --- a/certs/Makefile +++ b/certs/GNUmakefile diff --git a/certs/localhost.cnf b/certs/localhost.cnf index a7dc6cfe..61d59e72 100644 --- a/certs/localhost.cnf +++ b/certs/localhost.cnf @@ -11,7 +11,7 @@ otherName.2 = 1.3.6.1.5.5.7.8.5;FORMAT:UTF8,UTF8:localhost [distinguished_name] countryName = GB organizationName = Prosody IM -organizationalUnitName = http://prosody.im/doc/certificates +organizationalUnitName = https://prosody.im/doc/certificates commonName = Example certificate [req] diff --git a/certs/makefile b/certs/makefile new file mode 100644 index 00000000..b0614072 --- /dev/null +++ b/certs/makefile @@ -0,0 +1,18 @@ +.DEFAULT: localhost.crt +keysize=2048 + +# How to: +# First, `make yourhost.cnf` which creates a openssl config file. +# Then edit this file and fill in the details you want it to have, +# and add or change hosts and components it should cover. +# Then `make yourhost.key` to create your private key, you can +# include keysize=number to change the size of the key. +# Then you can either `make yourhost.csr` to generate a certificate +# signing request that you can submit to a CA, or `make yourhost.crt` +# to generate a self signed certificate. + +${.TARGETS:M*.crt}: + openssl req -new -x509 -newkey rsa:$(keysize) -nodes -keyout ${.TARGET:R}.key \ + -days 365 -sha256 -out $@ -utf8 -subj /CN=${.TARGET:R} + +.SUFFIXES: .key .crt @@ -528,6 +528,8 @@ OPENSSL_LIBS="-l$OPENSSL_LIB" if [ "$PRNG" = "OPENSSL" ]; then PRNGLIBS=$OPENSSL_LIBS +elif [ "$PRNG" = "ARC4RANDOM" -a "$(uname)" = "Linux" ]; then + PRNGLIBS="-lbsd" fi # Write config diff --git a/core/certmanager.lua b/core/certmanager.lua index dac4baa4..5282a6f5 100644 --- a/core/certmanager.lua +++ b/core/certmanager.lua @@ -55,6 +55,7 @@ local luasec_has = softreq"ssl.config" or { }; local _ENV = nil; +-- luacheck: std none -- Global SSL options if not overridden per-host local global_ssl_config = configmanager.get("*", "ssl"); diff --git a/core/configmanager.lua b/core/configmanager.lua index 5a544375..d9482b81 100644 --- a/core/configmanager.lua +++ b/core/configmanager.lua @@ -7,12 +7,10 @@ -- local _G = _G; -local setmetatable, rawget, rawset, io, error, dofile, type, pairs, table = - setmetatable, rawget, rawset, io, error, dofile, type, pairs, table; +local setmetatable, rawget, rawset, io, error, dofile, type, pairs = + setmetatable, rawget, rawset, io, error, dofile, type, pairs; local format, math_max = string.format, math.max; -local fire_event = prosody and prosody.events.fire_event or function () end; - local envload = require"util.envload".envload; local deps = require"util.dependencies"; local resolve_relative_path = require"util.paths".resolve_relative_path; @@ -24,10 +22,11 @@ local nameprep = encodings and encodings.stringprep.nameprep or function (host) local _M = {}; local _ENV = nil; +-- luacheck: std none _M.resolve_relative_path = resolve_relative_path; -- COMPAT -local parsers = {}; +local parser = nil; local config_mt = { __index = function (t, _) return rawget(t, "*"); end}; local config = setmetatable({ ["*"] = { } }, config_mt); @@ -77,19 +76,14 @@ end function _M.load(filename, config_format) config_format = config_format or filename:match("%w+$"); - if parsers[config_format] and parsers[config_format].load then + if config_format == "lua" then local f, err = io.open(filename); if f then local new_config = setmetatable({ ["*"] = { } }, config_mt); - local ok, err = parsers[config_format].load(f:read("*a"), filename, new_config); + local ok, err = parser.load(f:read("*a"), filename, new_config); f:close(); if ok then config = new_config; - fire_event("config-reloaded", { - filename = filename, - format = config_format, - config = config - }); end return ok, "parser", err; end @@ -103,26 +97,11 @@ function _M.load(filename, config_format) end end -function _M.addparser(config_format, parser) - if config_format and parser then - parsers[config_format] = parser; - end -end - --- _M needed to avoid name clash with local 'parsers' -function _M.parsers() - local p = {}; - for config_format in pairs(parsers) do - table.insert(p, config_format); - end - return p; -end - -- Built-in Lua parser do local pcall = _G.pcall; - parsers.lua = {}; - function parsers.lua.load(data, config_file, config_table) + parser = {}; + function parser.load(data, config_file, config_table) local env; -- The ' = true' are needed so as not to set off __newindex when we assign the functions below env = setmetatable({ @@ -211,7 +190,7 @@ do file = resolve_relative_path(config_file:gsub("[^"..path_sep.."]+$", ""), file); local f, err = io.open(file); if f then - local ret, err = parsers.lua.load(f:read("*a"), file, config_table); + local ret, err = parser.load(f:read("*a"), file, config_table); if not ret then error(err:gsub("%[string.-%]", file), 0); end end if not f then error("Error loading included "..file..": "..err, 0); end diff --git a/core/hostmanager.lua b/core/hostmanager.lua index 106a8ef2..9acca517 100644 --- a/core/hostmanager.lua +++ b/core/hostmanager.lua @@ -12,8 +12,6 @@ local events_new = require "util.events".new; local disco_items = require "util.multitable".new(); local NULL = {}; -local jid_split = require "util.jid".split; - local log = require "util.logger".init("hostmanager"); local hosts = prosody.hosts; @@ -24,11 +22,12 @@ end local incoming_s2s = _G.prosody.incoming_s2s; local core_route_stanza = _G.prosody.core_route_stanza; -local pairs, select, rawget = pairs, select, rawget; +local pairs, rawget = pairs, rawget; local tostring, type = tostring, type; local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local host_mt = { } function host_mt:__tostring() @@ -71,13 +70,6 @@ end prosody_events.add_handler("server-starting", load_enabled_hosts); local function host_send(stanza) - local name, stanza_type = stanza.name, stanza.attr.type; - if stanza_type == "error" or (name == "iq" and stanza_type == "result") then - local dest_host_name = select(2, jid_split(stanza.attr.to)); - local dest_host = hosts[dest_host_name] or { type = "unknown" }; - log("warn", "Unhandled response sent to %s host %s: %s", dest_host.type, dest_host_name, tostring(stanza)); - return; - end core_route_stanza(nil, stanza); end diff --git a/core/loggingmanager.lua b/core/loggingmanager.lua index 004f4c3b..cfa8246a 100644 --- a/core/loggingmanager.lua +++ b/core/loggingmanager.lua @@ -5,7 +5,6 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- --- luacheck: globals log prosody.log local format = require "util.format".format; local setmetatable, rawset, pairs, ipairs, type = @@ -18,12 +17,9 @@ local getstyle, getstring = require "util.termcolours".getstyle, require "util.t local config = require "core.configmanager"; local logger = require "util.logger"; -local prosody = prosody; - -_G.log = logger.init("general"); -prosody.log = logger.init("general"); local _ENV = nil; +-- luacheck: std none -- The log config used if none specified in the config file (see reload_logging for initialization) local default_logging; @@ -154,13 +150,8 @@ local function reload_logging() for name, sink_maker in pairs(old_sink_types) do log_sink_types[name] = sink_maker; end - - prosody.events.fire_event("logging-reloaded"); end -reload_logging(); -prosody.events.add_handler("config-reloaded", reload_logging); - --- Definition of built-in logging sinks --- -- Null sink, must enter log_sink_types *first* diff --git a/core/moduleapi.lua b/core/moduleapi.lua index 7d954c1f..64c4ce43 100644 --- a/core/moduleapi.lua +++ b/core/moduleapi.lua @@ -20,9 +20,10 @@ local st = require "util.stanza"; local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; local error, setmetatable, type = error, setmetatable, type; local ipairs, pairs, select = ipairs, pairs, select; -local unpack = table.unpack or unpack; --luacheck: ignore 113 local tonumber, tostring = tonumber, tostring; local require = require; +local pack = table.pack or function(...) return {n=select("#",...), ...}; end -- table.pack is only in 5.2 +local unpack = table.unpack or unpack; --luacheck: ignore 113 -- renamed in 5.2 local prosody = prosody; local hosts = prosody.hosts; @@ -70,20 +71,6 @@ end function api:add_extension(data) self:add_item("extension", data); end -function api:has_feature(xmlns) - for _, feature in ipairs(self:get_host_items("feature")) do - if feature == xmlns then return true; end - end - return false; -end -function api:has_identity(category, identity_type, name) - for _, id in ipairs(self:get_host_items("identity")) do - if id.category == category and id.type == identity_type and id.name == name then - return true; - end - end - return false; -end function api:fire_event(...) return (hosts[self.host] or prosody).events.fire_event(...); @@ -160,6 +147,9 @@ function api:depends(name) end end); end + if self:get_option_inherited_set("modules_disabled", {}):contains(name) then + self:log("warn", "Loading prerequisite mod_%s despite it being disabled", name); + end local mod = modulemanager.get_module(self.host, name) or modulemanager.get_module("*", name); if mod and mod.module.host == "*" and self.host ~= "*" and modulemanager.module_has_method(mod, "add_host") then @@ -381,11 +371,29 @@ function api:broadcast(jids, stanza, iter) end end -function api:add_timer(delay, callback) - return timer.add_task(delay, function (t) - if self.loaded == false then return; end - return callback(t); - end); +local timer_methods = { } +local timer_mt = { + __index = timer_methods; +} +function timer_methods:stop( ) + timer.stop(self.id); +end +timer_methods.disarm = timer_methods.stop +function timer_methods:reschedule(delay) + timer.reschedule(self.id, delay) +end + +local function timer_callback(now, id, t) --luacheck: ignore 212/id + if t.module_env.loaded == false then return; end + return t.callback(now, unpack(t, 1, t.n)); +end + +function api:add_timer(delay, callback, ...) + local t = pack(...) + t.module_env = self; + t.callback = callback; + t.id = timer.add_task(delay, timer_callback, t); + return setmetatable(t, timer_mt); end local path_sep = package.config:sub(1,1); diff --git a/core/modulemanager.lua b/core/modulemanager.lua index 771f6fb1..81c28aa0 100644 --- a/core/modulemanager.lua +++ b/core/modulemanager.lua @@ -15,8 +15,8 @@ local set = require "util.set"; local new_multitable = require "util.multitable".new; local api = require "core.moduleapi"; -- Module API container -local hosts = hosts; local prosody = prosody; +local hosts = prosody.hosts; local xpcall = xpcall; local setmetatable, rawget = setmetatable, rawget; @@ -38,6 +38,7 @@ local component_inheritable_modules = {"tls", "saslauth", "dialback", "iq", "s2s local _G = _G; local _ENV = nil; +-- luacheck: std none local load_modules_for_host, load, unload, reload, get_module, get_items; local get_modules, is_loaded, module_has_method, call_module_method; @@ -45,8 +46,8 @@ local get_modules, is_loaded, module_has_method, call_module_method; -- [host] = { [module] = module_env } local modulemap = { ["*"] = {} }; --- Load modules when a host is activated -function load_modules_for_host(host) +-- Get the list of modules to be loaded on a host +local function get_modules_for_host(host) local component = config.get(host, "component_module"); local global_modules_enabled = config.get("*", "modules_enabled"); @@ -70,8 +71,16 @@ function load_modules_for_host(host) modules:add("admin_telnet"); end - if component then - load(host, component); + return modules, component; +end + +-- Load modules when a host is activated +function load_modules_for_host(host) + local modules, component_module = get_modules_for_host(host); + + -- Ensure component module is loaded first + if component_module then + load(host, component_module); end for module in modules do load(host, module); @@ -323,6 +332,7 @@ function call_module_method(module, method, ...) end return { + get_modules_for_host = get_modules_for_host; load_modules_for_host = load_modules_for_host; load = load; unload = unload; diff --git a/core/portmanager.lua b/core/portmanager.lua index 5b6476f3..1ed37da0 100644 --- a/core/portmanager.lua +++ b/core/portmanager.lua @@ -15,6 +15,7 @@ local prosody = prosody; local fire_event = prosody.events.fire_event; local _ENV = nil; +-- luacheck: std none --- Config diff --git a/core/rostermanager.lua b/core/rostermanager.lua index 65be0de0..aa1ba9f3 100644 --- a/core/rostermanager.lua +++ b/core/rostermanager.lua @@ -15,7 +15,7 @@ local pairs = pairs; local tostring = tostring; local type = type; -local hosts = hosts; +local hosts = prosody.hosts; local bare_sessions = prosody.bare_sessions; local um_user_exists = require "core.usermanager".user_exists; @@ -23,6 +23,7 @@ local st = require "util.stanza"; local storagemanager = require "core.storagemanager"; local _ENV = nil; +-- luacheck: std none local save_roster; -- forward declaration diff --git a/core/s2smanager.lua b/core/s2smanager.lua index d84572f3..58269c49 100644 --- a/core/s2smanager.lua +++ b/core/s2smanager.lua @@ -17,12 +17,13 @@ local logger_init = require "util.logger".init; local log = logger_init("s2smanager"); local prosody = _G.prosody; -incoming_s2s = {}; +local incoming_s2s = {}; +_G.incoming_s2s = incoming_s2s; prosody.incoming_s2s = incoming_s2s; -local incoming_s2s = incoming_s2s; local fire_event = prosody.events.fire_event; local _ENV = nil; +-- luacheck: std none local function new_incoming(conn) local session = { conn = conn, type = "s2sin_unauthed", direction = "incoming", hosts = {} }; @@ -64,6 +65,7 @@ local function retire_session(session, reason) function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); end function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end + session.thread = { run = function (_, data) return session.data(data) end }; session.sends2s = session.send; return setmetatable(session, resting_session); end diff --git a/core/sessionmanager.lua b/core/sessionmanager.lua index 6c9ecc24..2b429df9 100644 --- a/core/sessionmanager.lua +++ b/core/sessionmanager.lua @@ -10,7 +10,7 @@ local tostring, setmetatable = tostring, setmetatable; local pairs, next= pairs, next; -local hosts = hosts; +local hosts = prosody.hosts; local full_sessions = prosody.full_sessions; local bare_sessions = prosody.bare_sessions; @@ -20,12 +20,13 @@ local rm_load_roster = require "core.rostermanager".load_roster; local config_get = require "core.configmanager".get; local resourceprep = require "util.encodings".stringprep.resourceprep; local nodeprep = require "util.encodings".stringprep.nodeprep; -local uuid_generate = require "util.uuid".generate; +local generate_identifier = require "util.id".short; local initialize_filters = require "util.filters".initialize; local gettime = require "socket".gettime; local _ENV = nil; +-- luacheck: std none local function new_session(conn) local session = { conn = conn, type = "c2s_unauthed", conntime = gettime() }; @@ -73,6 +74,7 @@ local function retire_session(session) function session.send(data) log("debug", "Discarding data sent to resting session: %s", tostring(data)); return false; end function session.data(data) log("debug", "Discarding data received from resting session: %s", tostring(data)); end + session.thread = { run = function (_, data) return session.data(data) end }; return setmetatable(session, resting_session); end @@ -136,7 +138,7 @@ local function bind_resource(session, resource) end resource = resourceprep(resource); - resource = resource ~= "" and resource or uuid_generate(); + resource = resource ~= "" and resource or generate_identifier(); --FIXME: Randomly-generated resources must be unique per-user, and never conflict with existing if not hosts[session.host].sessions[session.username] then @@ -150,7 +152,7 @@ local function bind_resource(session, resource) local policy = config_get(session.host, "conflict_resolve"); local increment; if policy == "random" then - resource = uuid_generate(); + resource = generate_identifier(); increment = true; elseif policy == "increment" then increment = true; -- TODO ping old resource diff --git a/core/stanza_router.lua b/core/stanza_router.lua index 0be92f88..49fe747c 100644 --- a/core/stanza_router.lua +++ b/core/stanza_router.lua @@ -19,7 +19,7 @@ local bare_sessions = _G.prosody.bare_sessions; local core_post_stanza, core_process_stanza, core_route_stanza; -function deprecated_warning(f) +local function deprecated_warning(f) _G[f] = function(...) log("warn", "Using the global %s() is deprecated, use module:send() or prosody.%s(). %s", f, f, debug.traceback()); return prosody[f](...); diff --git a/core/storagemanager.lua b/core/storagemanager.lua index c93438af..dea71733 100644 --- a/core/storagemanager.lua +++ b/core/storagemanager.lua @@ -1,17 +1,21 @@ local type, pairs = type, pairs; local setmetatable = setmetatable; +local rawset = rawset; local config = require "core.configmanager"; local datamanager = require "util.datamanager"; local modulemanager = require "core.modulemanager"; local multitable = require "util.multitable"; -local hosts = hosts; local log = require "util.logger".init("storagemanager"); +local async = require "util.async"; +local debug = debug; local prosody = prosody; +local hosts = prosody.hosts; local _ENV = nil; +-- luacheck: std none local olddm = {}; -- maintain old datamanager, for backwards compatibility for k,v in pairs(datamanager) do olddm[k] = v; end @@ -28,8 +32,34 @@ local null_storage_driver = setmetatable( } ); +local async_check = config.get("*", "storage_async_check") == true; + local stores_available = multitable.new(); +local function check_async_wrapper(event) + local store = event.store; + event.store = setmetatable({}, { + __index = function (t, method_name) + local original_method = store[method_name]; + if type(original_method) ~= "function" then + if original_method then + rawset(t, method_name, original_method); + end + return original_method; + end + local wrapped_method = function (...) + if not async.ready() then + log("warn", "ASYNC-01: Attempt to access storage outside async context, " + .."see https://prosody.im/doc/developers/async - %s", debug.traceback()); + end + return original_method(...); + end + rawset(t, method_name, wrapped_method); + return wrapped_method; + end; + }); +end + local function initialize_host(host) local host_session = hosts[host]; host_session.events.add_handler("item-added/storage-provider", function (event) @@ -41,6 +71,9 @@ local function initialize_host(host) local item = event.item; stores_available:set(host, item.name, nil); end); + if async_check then + host_session.events.add_handler("store-opened", check_async_wrapper); + end end prosody.events.add_handler("host-activated", initialize_host, 101); @@ -137,7 +170,7 @@ local map_shim_mt = { }; } -local open; +local open; -- forward declaration local function create_map_shim(host, store) local keyval_store, err = open(host, store, "keyval"); diff --git a/core/usermanager.lua b/core/usermanager.lua index f795e8ae..bb5669cf 100644 --- a/core/usermanager.lua +++ b/core/usermanager.lua @@ -13,17 +13,18 @@ local ipairs = ipairs; local jid_bare = require "util.jid".bare; local jid_prep = require "util.jid".prep; local config = require "core.configmanager"; -local hosts = hosts; local sasl_new = require "util.sasl".new; local storagemanager = require "core.storagemanager"; local prosody = _G.prosody; +local hosts = prosody.hosts; local setmetatable = setmetatable; local default_provider = "internal_plain"; local _ENV = nil; +-- luacheck: std none local function new_null_provider() local function dummy() return nil, "method not implemented"; end; diff --git a/doc/coding_style.txt b/doc/coding_style.txt index c9113e81..af44da1a 100644 --- a/doc/coding_style.txt +++ b/doc/coding_style.txt @@ -7,7 +7,7 @@ Please try to follow, and feel free to fix code you see not following this stand == Spacing == -No space between function names and parenthesis and parenthesis and paramters: +No space between function names and parenthesis and parenthesis and parameters: function foo(bar, baz) diff --git a/doc/names.txt b/doc/names.txt index 7a6ab1e9..7026c985 100644 --- a/doc/names.txt +++ b/doc/names.txt @@ -15,7 +15,7 @@ Thorns thought of: Eclaire - Idem (French) Adel - Random Younha - Read as "yuna" - Quezacotl - Mayan gods -> google for correct form and pronounciation + Quezacotl - Mayan gods -> google for correct form and pronunciation Carbuncle - FF8 Guardian Force ^^ Protos - Mars satellite mins - Derived from minstrel diff --git a/doc/net.server.lua b/doc/net.server.lua new file mode 100644 index 00000000..f07a2bd0 --- /dev/null +++ b/doc/net.server.lua @@ -0,0 +1,256 @@ +-- Prosody IM +-- Copyright (C) 2014,2016 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. + +--[[ +This file is a template for writing a net.server compatible backend. +]] + +--[[ +Read patterns (also called modes) can be one of: + - "*a": Read as much as possible + - "*l": Read until end of line +]] + +--- Handle API +local handle_mt = {}; +local handle_methods = {}; +handle_mt.__index = handle_methods; + +function handle_methods:set_mode(new_pattern) +end + +function handle_methods:setlistener(listeners) +end + +function handle_methods:setoption(option, value) +end + +function handle_methods:ip() +end + +function handle_methods:starttls(sslctx) +end + +function handle_methods:write(data) +end + +function handle_methods:close() +end + +function handle_methods:pause() +end + +function handle_methods:resume() +end + +--[[ +Returns + - socket: the socket object underlying this handle +]] +function handle_methods:socket() +end + +--[[ +Returns + - boolean: if an ssl context has been set on this handle +]] +function handle_methods:ssl() +end + + +--- Listeners API +local listeners = {} + +--[[ connect +Called when a client socket has established a connection with it's peer +]] +function listeners.onconnect(handle) +end + +--[[ incoming +Called when data is received +If reading data failed this will be called with `nil, "error message"` +]] +function listeners.onincoming(handle, buff, err) +end + +--[[ status +Known statuses: + - "ssl-handshake-complete" +]] +function listeners.onstatus(handle, status) +end + +--[[ disconnect +Called when the peer has closed the connection +]] +function listeners.ondisconnect(handle) +end + +--[[ drain +Called when the handle's write buffer is empty +]] +function listeners.ondrain(handle) +end + +--[[ readtimeout +Called when a socket inactivity timeout occurs +]] +function listeners.onreadtimeout(handle) +end + +--[[ detach: Called when other listeners are going to be removed +Allows for clean-up +]] +function listeners.ondetach(handle) +end + +--- Top level functions + +--[[ Returns the syscall level event mechanism in use. + +Returns: + - backend: e.g. "select", "epoll" +]] +local function get_backend() +end + +--[[ Starts the event loop. + +Returns: + - "quitting" +]] +local function loop() +end + +--[[ Stop a running loop() +]] +local function setquitting(quit) +end + + +--[[ Links to two handles together, so anything written to one is piped to the other + +Arguments: + - sender, receiver: handles to link + - buffersize: maximum #bytes until sender will be locked +]] +local function link(sender, receiver, buffersize) +end + +--[[ Binds and listens on the given address and port +If `sslctx` is given, the connecting clients will have to negotiate an SSL session + +Arguments: + - address: address to bind to, may be "*" to bind all addresses. will be resolved if it is a string. + - port: port to bind (as number) + - listeners: a table of listeners + - pattern: the read pattern + - sslctx: is a valid luasec constructor + +Returns: + - handle + - nil, "an error message": on failure (e.g. out of file descriptors) +]] +local function addserver(address, port, listeners, pattern, sslctx) +end + +--[[ Wraps a lua-socket socket client socket in a handle. +The socket must be already connected to the remote end. +If `sslctx` is given, a SSL session will be negotiated before listeners are called. + +Arguments: + - socket: the lua-socket object to wrap + - ip: returned by `handle:ip()` + - port: + - listeners: a table of listeners + - pattern: the read pattern + - sslctx: is a valid luasec constructor + - typ: the socket type, one of: + - "tcp" + - "tcp6" + - "udp" + +Returns: + - handle, socket + - nil, "an error message": on failure (e.g. ) +]] +local function wrapclient(socket, ip, serverport, listeners, pattern, sslctx) +end + +--[[ Connects to the given address and port +If `sslctx` is given, a SSL session will be negotiated before listeners are called. + +Arguments: + - address: address to connect to. will be resolved if it is a string. + - port: port to connect to (as number) + - listeners: a table of listeners + - pattern: the read pattern + - sslctx: is a valid luasec constructor + - typ: the socket type, one of: + - "tcp" + - "tcp6" + - "udp" + +Returns: + - handle + - nil, "an error message": on failure (e.g. out of file descriptors) +]] +local function addclient(address, port, listeners, pattern, sslctx, typ) +end + +--[[ Close all handles +]] +local function closeall() +end + +--[[ The callback should be called after `delay` seconds. +The callback should be called with the time at the point of firing. +If the callback returns a number, it should be called again after that many seconds. + +Arguments: + - delay: number of seconds to wait + - callback: function to call. +]] +local function add_task(delay, callback) +end + +--[[ Adds a handler for when a signal is fired. +Optional to implement +callback does not take any arguments + +Arguments: + - signal_id: the signal id (as number) to listen for + - handler: callback +]] +local function hook_signal(signal_id, handler) +end + +--[[ Adds a low-level FD watcher +Arguments: +- fd_number: A non-negative integer representing a file descriptor or + object with a :getfd() method returning one +- on_readable: Optional callback for when the FD is readable +- on_writable: Optional callback for when the FD is writable + +Returns: +- net.server handle +]] +local function watchfd(fd_number, on_readable, on_writable) +end + +return { + get_backend = get_backend; + loop = loop; + setquitting = setquitting; + link = link; + addserver = addserver; + wrapclient = wrapclient; + addclient = addclient; + closeall = closeall; + hook_signal = hook_signal; + watchfd = watchfd; +} diff --git a/doc/session.txt b/doc/session.txt index fc6eec17..c1f99947 100644 --- a/doc/session.txt +++ b/doc/session.txt @@ -15,12 +15,12 @@ session { full_jid -- convenience for the above 3 as string in username@host/resource form (not defined before resource binding)
priority -- the resource priority, default: 0
presence -- the last non-directed presence with no type attribute. initially nil. reset to nil on unavailable presence.
- interested -- true if the resource requested the roster. Interested resources recieve roster updates. Initially nil.
+ interested -- true if the resource requested the roster. Interested resources receive roster updates. Initially nil.
roster -- the user's roster. Loaded as soon as the resource is bound (session becomes a connected resource).
-- methods --
send(x) -- converts x to a string, and writes it to the connection
- disconnect(x) -- Disconnect the user and clean up the session, best call sessionmanager.destroy_session() instead of this in most cases
+ close(x) -- Disconnect the user and clean up the session, best call sessionmanager.destroy_session() instead of this in most cases
}
if session.full_jid (also session.roster and session.resource) then this is a "connected resource"
diff --git a/makefile b/makefile new file mode 100644 index 00000000..d19ec24d --- /dev/null +++ b/makefile @@ -0,0 +1,101 @@ + +include config.unix + +BIN = $(DESTDIR)$(PREFIX)/bin +CONFIG = $(DESTDIR)$(SYSCONFDIR) +MODULES = $(DESTDIR)$(LIBDIR)/prosody/modules +SOURCE = $(DESTDIR)$(LIBDIR)/prosody +DATA = $(DESTDIR)$(DATADIR) +MAN = $(DESTDIR)$(PREFIX)/share/man + +INSTALLEDSOURCE = $(LIBDIR)/prosody +INSTALLEDCONFIG = $(SYSCONFDIR) +INSTALLEDMODULES = $(LIBDIR)/prosody/modules +INSTALLEDDATA = $(DATADIR) + +INSTALL=install -p +INSTALL_DATA=$(INSTALL) -m644 +INSTALL_EXEC=$(INSTALL) -m755 +MKDIR=install -d +MKDIR_PRIVATE=$(MKDIR) -m750 + +.PHONY: all test clean install + +all: prosody.install prosodyctl.install prosody.cfg.lua.install prosody.version + $(MAKE) -C util-src install +.if $(EXCERTS) == "yes" + $(MAKE) -C certs localhost.crt example.com.crt +.endif + +install: prosody.install prosodyctl.install prosody.cfg.lua.install util/encodings.so util/encodings.so util/pposix.so util/signal.so + $(MKDIR) $(BIN) $(CONFIG) $(MODULES) $(SOURCE) + $(MKDIR_PRIVATE) $(DATA) + $(MKDIR) $(MAN)/man1 + $(MKDIR) $(CONFIG)/certs + $(MKDIR) $(SOURCE)/core $(SOURCE)/net $(SOURCE)/util + $(INSTALL_EXEC) ./prosody.install $(BIN)/prosody + $(INSTALL_EXEC) ./prosodyctl.install $(BIN)/prosodyctl + $(INSTALL_DATA) core/*.lua $(SOURCE)/core + $(INSTALL_DATA) net/*.lua $(SOURCE)/net + $(MKDIR) $(SOURCE)/net/http $(SOURCE)/net/resolvers $(SOURCE)/net/websocket + $(INSTALL_DATA) net/http/*.lua $(SOURCE)/net/http + $(INSTALL_DATA) net/resolvers/*.lua $(SOURCE)/net/resolvers + $(INSTALL_DATA) net/websocket/*.lua $(SOURCE)/net/websocket + $(INSTALL_DATA) util/*.lua $(SOURCE)/util + $(INSTALL_DATA) util/*.so $(SOURCE)/util + $(MKDIR) $(SOURCE)/util/sasl + $(INSTALL_DATA) util/sasl/*.lua $(SOURCE)/util/sasl + $(MKDIR) $(MODULES)/mod_s2s $(MODULES)/mod_pubsub $(MODULES)/adhoc $(MODULES)/muc $(MODULES)/mod_mam + $(INSTALL_DATA) plugins/*.lua $(MODULES) + $(INSTALL_DATA) plugins/mod_s2s/*.lua $(MODULES)/mod_s2s + $(INSTALL_DATA) plugins/mod_pubsub/*.lua $(MODULES)/mod_pubsub + $(INSTALL_DATA) plugins/adhoc/*.lua $(MODULES)/adhoc + $(INSTALL_DATA) plugins/muc/*.lua $(MODULES)/muc + $(INSTALL_DATA) plugins/mod_mam/*.lua $(MODULES)/mod_mam +.if $(EXCERTS) == "yes" + $(INSTALL_DATA) certs/localhost.crt certs/localhost.key $(CONFIG)/certs + $(INSTALL_DATA) certs/example.com.crt certs/example.com.key $(CONFIG)/certs +.endif + $(INSTALL_DATA) man/prosodyctl.man $(MAN)/man1/prosodyctl.1 + test -f $(CONFIG)/prosody.cfg.lua || $(INSTALL_DATA) prosody.cfg.lua.install $(CONFIG)/prosody.cfg.lua + -test -f prosody.version && $(INSTALL_DATA) prosody.version $(SOURCE)/prosody.version + $(MAKE) install -C util-src + +clean: + rm -f prosody.install + rm -f prosodyctl.install + rm -f prosody.cfg.lua.install + rm -f prosody.version + $(MAKE) clean -C util-src + +test: + busted --lua=$(RUNWITH) + + +prosody.install: prosody + sed "1s| lua$$| $(RUNWITH)|; \ + s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ + s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|; \ + s|^CFG_DATADIR=.*;$$|CFG_DATADIR='$(INSTALLEDDATA)';|; \ + s|^CFG_PLUGINDIR=.*;$$|CFG_PLUGINDIR='$(INSTALLEDMODULES)/';|;" < prosody > $@ + +prosodyctl.install: prosodyctl + sed "1s| lua$$| $(RUNWITH)|; \ + s|^CFG_SOURCEDIR=.*;$$|CFG_SOURCEDIR='$(INSTALLEDSOURCE)';|; \ + s|^CFG_CONFIGDIR=.*;$$|CFG_CONFIGDIR='$(INSTALLEDCONFIG)';|; \ + s|^CFG_DATADIR=.*;$$|CFG_DATADIR='$(INSTALLEDDATA)';|; \ + s|^CFG_PLUGINDIR=.*;$$|CFG_PLUGINDIR='$(INSTALLEDMODULES)/';|;" < prosodyctl > $@ + +prosody.cfg.lua.install: prosody.cfg.lua.dist + sed 's|certs/|$(INSTALLEDCONFIG)/certs/|' prosody.cfg.lua.dist > $@ + +prosody.version: + test -f prosody.release && \ + cp prosody.release $@ || \ + test -f .hg_archival.txt && \ + sed -n 's/^node: \(............\).*/\1/p' .hg_archival.txt > $@ || \ + test -f .hg/dirstate && \ + hexdump -n6 -e'6/1 "%02x"' .hg/dirstate > $@ || \ + echo unknown > $@ + + diff --git a/net/adns.lua b/net/adns.lua index a19cbd59..560e4b53 100644 --- a/net/adns.lua +++ b/net/adns.lua @@ -17,6 +17,7 @@ local setmetatable = setmetatable; local function dummy_send(sock, data, i, j) return (j-i)+1; end local _ENV = nil; +-- luacheck: std none local async_resolver_methods = {}; local async_resolver_mt = { __index = async_resolver_methods }; diff --git a/net/connect.lua b/net/connect.lua new file mode 100644 index 00000000..02ab4cc0 --- /dev/null +++ b/net/connect.lua @@ -0,0 +1,92 @@ +local server = require "net.server"; +local log = require "util.logger".init("net.connect"); +local new_id = require "util.id".short; + +local pending_connection_methods = {}; +local pending_connection_mt = { + __name = "pending_connection"; + __index = pending_connection_methods; + __tostring = function (p) + return "<pending connection "..p.id.." to "..tostring(p.target_resolver.hostname)..">"; + end; +}; + +function pending_connection_methods:log(level, message, ...) + log(level, "[pending connection %s] "..message, self.id, ...); +end + +-- pending_connections_map[conn] = pending_connection +local pending_connections_map = {}; + +local pending_connection_listeners = {}; + +local function attempt_connection(p) + p:log("debug", "Checking for targets..."); + if p.conn then + pending_connections_map[p.conn] = nil; + p.conn = nil; + end + p.target_resolver:next(function (conn_type, ip, port, extra) + if not conn_type then + -- No more targets to try + p:log("debug", "No more connection targets to try"); + if p.listeners.onfail then + p.listeners.onfail(p.data, p.last_error or "unable to resolve service"); + end + return; + end + p:log("debug", "Next target to try is %s:%d", ip, port); + local conn, err = server.addclient(ip, port, pending_connection_listeners, p.options.pattern or "*a", p.options.sslctx, conn_type, extra); + if not conn then + log("debug", "Connection attempt failed immediately: %s", tostring(err)); + p.last_error = err or "unknown reason"; + return attempt_connection(p); + end + p.conn = conn; + pending_connections_map[conn] = p; + end); +end + +function pending_connection_listeners.onconnect(conn) + local p = pending_connections_map[conn]; + if not p then + log("warn", "Successful connection, but unexpected! Closing."); + conn:close(); + return; + end + pending_connections_map[conn] = nil; + p:log("debug", "Successfully connected"); + if p.listeners.onattach then + p.listeners.onattach(conn, p.data); + end + conn:setlistener(p.listeners); + return p.listeners.onconnect(conn); +end + +function pending_connection_listeners.ondisconnect(conn, reason) + local p = pending_connections_map[conn]; + if not p then + log("warn", "Failed connection, but unexpected!"); + return; + end + p.last_error = reason or "unknown reason"; + p:log("debug", "Connection attempt failed: %s", p.last_error); + attempt_connection(p); +end + +local function connect(target_resolver, listeners, options, data) + local p = setmetatable({ + id = new_id(); + target_resolver = target_resolver; + listeners = assert(listeners); + options = options or {}; + data = data; + }, pending_connection_mt); + + p:log("debug", "Starting connection process"); + attempt_connection(p); +end + +return { + connect = connect; +}; diff --git a/net/connlisteners.lua b/net/connlisteners.lua index 000bfa63..9b8f88c3 100644 --- a/net/connlisteners.lua +++ b/net/connlisteners.lua @@ -3,15 +3,15 @@ local log = require "util.logger".init("net.connlisteners"); local traceback = debug.traceback; local _ENV = nil; +-- luacheck: std none local function fail() - log("error", "Attempt to use legacy connlisteners API. For more info see http://prosody.im/doc/developers/network"); + log("error", "Attempt to use legacy connlisteners API. For more info see https://prosody.im/doc/developers/network"); log("error", "Legacy connlisteners API usage, %s", traceback("", 2)); end return { register = fail; - register = fail; get = fail; start = fail; -- epic fail diff --git a/net/cqueues.lua b/net/cqueues.lua new file mode 100644 index 00000000..8c4c756f --- /dev/null +++ b/net/cqueues.lua @@ -0,0 +1,74 @@ +-- Prosody IM +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- +-- This module allows you to use cqueues with a net.server mainloop +-- + +local server = require "net.server"; +local cqueues = require "cqueues"; +assert(cqueues.VERSION >= 20150113, "cqueues newer than 20150113 required") + +-- Create a single top level cqueue +local cq; + +if server.cq then -- server provides cqueues object + cq = server.cq; +elseif server.get_backend() == "select" and server._addtimer then -- server_select + cq = cqueues.new(); + local function step() + assert(cq:loop(0)); + end + + -- Use wrapclient (as wrapconnection isn't exported) to get server_select to watch cq fd + local handler = server.wrapclient({ + getfd = function() return cq:pollfd(); end; + settimeout = function() end; -- Method just needs to exist + close = function() end; -- Need close method for 'closeall' + }, nil, nil, {}); + + -- Only need to listen for readable; cqueues handles everything under the hood + -- readbuffer is called when `select` notes an fd as readable + handler.readbuffer = step; + + -- Use server_select low lever timer facility, + -- this callback gets called *every* time there is a timeout in the main loop + server._addtimer(function(current_time) + -- This may end up in extra step()'s, but cqueues handles it for us. + step(); + return cq:timeout(); + end); +elseif server.event and server.base then -- server_event + cq = cqueues.new(); + -- Only need to listen for readable; cqueues handles everything under the hood + local EV_READ = server.event.EV_READ; + -- Convert a cqueues timeout to an acceptable timeout for luaevent + local function luaevent_safe_timeout(cq) + local t = cq:timeout(); + -- if you give luaevent 0 or nil, it re-uses the previous timeout. + if t == 0 then + t = 0.000001; -- 1 microsecond is the smallest that works (goes into a `struct timeval`) + elseif t == nil then -- pick something big if we don't have one + t = 0x7FFFFFFF; -- largest 32bit int + end + return t + end + local event_handle; + event_handle = server.base:addevent(cq:pollfd(), EV_READ, function(e) + -- Need to reference event_handle or this callback will get collected + -- This creates a circular reference that can only be broken if event_handle is manually :close()'d + local _ = event_handle; + -- Run as many cqueues things as possible (with a timeout of 0) + -- If an error is thrown, it will break the libevent loop; but prosody resumes after logging a top level error + assert(cq:loop(0)); + return EV_READ, luaevent_safe_timeout(cq); + end, luaevent_safe_timeout(cq)); +else + error "NYI" +end + +return { + cq = cq; +} diff --git a/net/dns.lua b/net/dns.lua index 563c81a6..af5f1216 100644 --- a/net/dns.lua +++ b/net/dns.lua @@ -15,6 +15,7 @@ local socket = require "socket"; local timer = require "util.timer"; local new_ip = require "util.ip".new_ip; +local have_util_net, util_net = pcall(require, "util.net"); local _, windows = pcall(require, "util.windows"); local is_windows = (_ and windows) or os.getenv("WINDIR"); @@ -72,6 +73,7 @@ local default_timeout = 15; -------------------------------------------------- module dns local _ENV = nil; +-- luacheck: std none local dns = {}; @@ -119,11 +121,99 @@ end dns.types = { - 'A', 'NS', 'MD', 'MF', 'CNAME', 'SOA', 'MB', 'MG', 'MR', 'NULL', 'WKS', - 'PTR', 'HINFO', 'MINFO', 'MX', 'TXT', - [ 28] = 'AAAA', [ 29] = 'LOC', [ 33] = 'SRV', - [252] = 'AXFR', [253] = 'MAILB', [254] = 'MAILA', [255] = '*' }; - + [1] = "A", -- a host address,[RFC1035],, + [2] = "NS", -- an authoritative name server,[RFC1035],, + [3] = "MD", -- a mail destination (OBSOLETE - use MX),[RFC1035],, + [4] = "MF", -- a mail forwarder (OBSOLETE - use MX),[RFC1035],, + [5] = "CNAME", -- the canonical name for an alias,[RFC1035],, + [6] = "SOA", -- marks the start of a zone of authority,[RFC1035],, + [7] = "MB", -- a mailbox domain name (EXPERIMENTAL),[RFC1035],, + [8] = "MG", -- a mail group member (EXPERIMENTAL),[RFC1035],, + [9] = "MR", -- a mail rename domain name (EXPERIMENTAL),[RFC1035],, + [10] = "NULL", -- a null RR (EXPERIMENTAL),[RFC1035],, + [11] = "WKS", -- a well known service description,[RFC1035],, + [12] = "PTR", -- a domain name pointer,[RFC1035],, + [13] = "HINFO", -- host information,[RFC1035],, + [14] = "MINFO", -- mailbox or mail list information,[RFC1035],, + [15] = "MX", -- mail exchange,[RFC1035],, + [16] = "TXT", -- text strings,[RFC1035],, + [17] = "RP", -- for Responsible Person,[RFC1183],, + [18] = "AFSDB", -- for AFS Data Base location,[RFC1183][RFC5864],, + [19] = "X25", -- for X.25 PSDN address,[RFC1183],, + [20] = "ISDN", -- for ISDN address,[RFC1183],, + [21] = "RT", -- for Route Through,[RFC1183],, + [22] = "NSAP", -- "for NSAP address, NSAP style A record",[RFC1706],, + [23] = "NSAP-PTR", -- "for domain name pointer, NSAP style",[RFC1348][RFC1637][RFC1706],, + [24] = "SIG", -- for security signature,[RFC4034][RFC3755][RFC2535][RFC2536][RFC2537][RFC2931][RFC3110][RFC3008],, + [25] = "KEY", -- for security key,[RFC4034][RFC3755][RFC2535][RFC2536][RFC2537][RFC2539][RFC3008][RFC3110],, + [26] = "PX", -- X.400 mail mapping information,[RFC2163],, + [27] = "GPOS", -- Geographical Position,[RFC1712],, + [28] = "AAAA", -- IP6 Address,[RFC3596],, + [29] = "LOC", -- Location Information,[RFC1876],, + [30] = "NXT", -- Next Domain (OBSOLETE),[RFC3755][RFC2535],, + [31] = "EID", -- Endpoint Identifier,[Michael_Patton][http://ana-3.lcs.mit.edu/~jnc/nimrod/dns.txt],,1995-06 + [32] = "NIMLOC", -- Nimrod Locator,[1][Michael_Patton][http://ana-3.lcs.mit.edu/~jnc/nimrod/dns.txt],,1995-06 + [33] = "SRV", -- Server Selection,[1][RFC2782],, + [34] = "ATMA", -- ATM Address,"[ ATM Forum Technical Committee, ""ATM Name System, V2.0"", Doc ID: AF-DANS-0152.000, July 2000. Available from and held in escrow by IANA.]",, + [35] = "NAPTR", -- Naming Authority Pointer,[RFC2915][RFC2168][RFC3403],, + [36] = "KX", -- Key Exchanger,[RFC2230],, + [37] = "CERT", -- CERT,[RFC4398],, + [38] = "A6", -- A6 (OBSOLETE - use AAAA),[RFC3226][RFC2874][RFC6563],, + [39] = "DNAME", -- DNAME,[RFC6672],, + [40] = "SINK", -- SINK,[Donald_E_Eastlake][http://tools.ietf.org/html/draft-eastlake-kitchen-sink],,1997-11 + [41] = "OPT", -- OPT,[RFC6891][RFC3225],, + [42] = "APL", -- APL,[RFC3123],, + [43] = "DS", -- Delegation Signer,[RFC4034][RFC3658],, + [44] = "SSHFP", -- SSH Key Fingerprint,[RFC4255],, + [45] = "IPSECKEY", -- IPSECKEY,[RFC4025],, + [46] = "RRSIG", -- RRSIG,[RFC4034][RFC3755],, + [47] = "NSEC", -- NSEC,[RFC4034][RFC3755],, + [48] = "DNSKEY", -- DNSKEY,[RFC4034][RFC3755],, + [49] = "DHCID", -- DHCID,[RFC4701],, + [50] = "NSEC3", -- NSEC3,[RFC5155],, + [51] = "NSEC3PARAM", -- NSEC3PARAM,[RFC5155],, + [52] = "TLSA", -- TLSA,[RFC6698],, + [53] = "SMIMEA", -- S/MIME cert association,[RFC8162],SMIMEA/smimea-completed-template,2015-12-01 + -- [54] = "Unassigned", -- ,,, + [55] = "HIP", -- Host Identity Protocol,[RFC8005],, + [56] = "NINFO", -- NINFO,[Jim_Reid],NINFO/ninfo-completed-template,2008-01-21 + [57] = "RKEY", -- RKEY,[Jim_Reid],RKEY/rkey-completed-template,2008-01-21 + [58] = "TALINK", -- Trust Anchor LINK,[Wouter_Wijngaards],TALINK/talink-completed-template,2010-02-17 + [59] = "CDS", -- Child DS,[RFC7344],CDS/cds-completed-template,2011-06-06 + [60] = "CDNSKEY", -- DNSKEY(s) the Child wants reflected in DS,[RFC7344],,2014-06-16 + [61] = "OPENPGPKEY", -- OpenPGP Key,[RFC7929],OPENPGPKEY/openpgpkey-completed-template,2014-08-12 + [62] = "CSYNC", -- Child-To-Parent Synchronization,[RFC7477],,2015-01-27 + -- [63 .. 98] = "Unassigned", -- ,,, + [99] = "SPF", -- ,[RFC7208],, + [100] = "UINFO", -- ,[IANA-Reserved],, + [101] = "UID", -- ,[IANA-Reserved],, + [102] = "GID", -- ,[IANA-Reserved],, + [103] = "UNSPEC", -- ,[IANA-Reserved],, + [104] = "NID", -- ,[RFC6742],ILNP/nid-completed-template, + [105] = "L32", -- ,[RFC6742],ILNP/l32-completed-template, + [106] = "L64", -- ,[RFC6742],ILNP/l64-completed-template, + [107] = "LP", -- ,[RFC6742],ILNP/lp-completed-template, + [108] = "EUI48", -- an EUI-48 address,[RFC7043],EUI48/eui48-completed-template,2013-03-27 + [109] = "EUI64", -- an EUI-64 address,[RFC7043],EUI64/eui64-completed-template,2013-03-27 + -- [110 .. 248] = "Unassigned", -- ,,, + [249] = "TKEY", -- Transaction Key,[RFC2930],, + [250] = "TSIG", -- Transaction Signature,[RFC2845],, + [251] = "IXFR", -- incremental transfer,[RFC1995],, + [252] = "AXFR", -- transfer of an entire zone,[RFC1035][RFC5936],, + [253] = "MAILB", -- "mailbox-related RRs (MB, MG or MR)",[RFC1035],, + [254] = "MAILA", -- mail agent RRs (OBSOLETE - see MX),[RFC1035],, + [255] = "*", -- A request for all records the server/cache has available,[RFC1035][RFC6895],, + [256] = "URI", -- URI,[RFC7553],URI/uri-completed-template,2011-02-22 + [257] = "CAA", -- Certification Authority Restriction,[RFC6844],CAA/caa-completed-template,2011-04-07 + [258] = "AVC", -- Application Visibility and Control,[Wolfgang_Riedel],AVC/avc-completed-template,2016-02-26 + [259] = "DOA", -- Digital Object Architecture,[draft-durand-doa-over-dns],DOA/doa-completed-template,2017-08-30 + -- [260 .. 32767] = "Unassigned", -- ,,, + [32768] = "TA", -- DNSSEC Trust Authorities,"[Sam_Weiler][http://cameo.library.cmu.edu/][ Deploying DNSSEC Without a Signed Root. Technical Report 1999-19, Information Networking Institute, Carnegie Mellon University, April 2004.]",,2005-12-13 + [32769] = "DLV", -- DNSSEC Lookaside Validation,[RFC4431],, + -- [32770 .. 65279] = "Unassigned", -- ,,, + -- [65280 .. 65534] = "Private use", -- ,,, + -- [65535] = "Reserved", -- ,,, +} dns.classes = { 'IN', 'CS', 'CH', 'HS', [255] = '*' }; @@ -391,6 +481,12 @@ function resolver:A(rr) -- - - - - - - - - - - - - - - - - - - - - - - - A rr.a = string.format('%i.%i.%i.%i', b1, b2, b3, b4); end +if have_util_net and util_net.ntop then + function resolver:A(rr) + rr.a = util_net.ntop(self:sub(4)); + end +end + function resolver:AAAA(rr) local addr = {}; for _ = 1, rr.rdlength, 2 do @@ -411,6 +507,12 @@ function resolver:AAAA(rr) rr.aaaa = addr:gsub(zeros[1], "::", 1):gsub("^0::", "::"):gsub("::0$", "::"); end +if have_util_net and util_net.ntop then + function resolver:AAAA(rr) + rr.aaaa = util_net.ntop(self:sub(16)); + end +end + function resolver:CNAME(rr) -- - - - - - - - - - - - - - - - - - - - CNAME rr.cname = self:name(); end diff --git a/net/http.lua b/net/http.lua index effb0ef5..6e5ad67c 100644 --- a/net/http.lua +++ b/net/http.lua @@ -13,9 +13,10 @@ local util_http = require "util.http"; local events = require "util.events"; local verify_identity = require"util.x509".verify_identity; -local ssl_available = pcall(require, "ssl"); +local basic_resolver = require "net.resolvers.basic"; +local connect = require "net.connect".connect; -local server = require "net.server" +local ssl_available = pcall(require, "ssl"); local t_insert, t_concat = table.insert, table.concat; local pairs = pairs; @@ -27,6 +28,7 @@ local setmetatable = setmetatable; local log = require "util.logger".init("http"); local _ENV = nil; +-- luacheck: std none local requests = {}; -- Open requests @@ -34,9 +36,78 @@ local function make_id(req) return (tostring(req):match("%x+$")); end local listener = { default_port = 80, default_mode = "*a" }; +-- Request-related helper functions +local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); return err; end +local function log_if_failed(req, ret, ...) + if not ret then + log("error", "Request '%s': error in callback: %s", req.id, tostring((...))); + if not req.suppress_errors then + error(...); + end + end + return ...; +end + +local function destroy_request(request) + local conn = request.conn; + if conn then + request.conn = nil; + conn:close() + end +end + +local function request_reader(request, data, err) + if not request.parser then + local function error_cb(reason) + if request.callback then + request.callback(reason or "connection-closed", 0, request); + request.callback = nil; + end + destroy_request(request); + end + + if not data then + error_cb(err); + return; + end + + local function success_cb(r) + if request.callback then + request.callback(r.body, r.code, r, request); + request.callback = nil; + end + destroy_request(request); + end + local function options_cb() + return request; + end + request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); + end + request.parser:feed(data); +end + +-- Connection listener callbacks function listener.onconnect(conn) local req = requests[conn]; + -- Initialize request object + req.write = function (...) return req.conn:write(...); end + local callback = req.callback; + req.callback = function (content, code, response, request) + do + local event = { http = req.http, url = req.url, request = req, response = response, content = content, code = code, callback = req.callback }; + req.http.events.fire_event("response", event); + content, code, response = event.content, event.code, event.response; + end + + log("debug", "Request '%s': Calling callback, status %s", req.id, code or "---"); + return log_if_failed(req.id, xpcall(function () return callback(content, code, response, request) end, handleerr)); + end + req.reader = request_reader; + req.state = "status"; + + requests[req.conn] = req; + -- Validate certificate if not req.insecure and conn:ssl() then local sock = conn:socket(); @@ -96,58 +167,24 @@ function listener.ondisconnect(conn, err) requests[conn] = nil; end -function listener.ondetach(conn) - requests[conn] = nil; -end - -local function destroy_request(request) - if request.conn then - request.conn = nil; - request.handler:close() - end +function listener.onattach(conn, req) + requests[conn] = req; + req.conn = conn; end -local function request_reader(request, data, err) - if not request.parser then - local function error_cb(reason) - if request.callback then - request.callback(reason or "connection-closed", 0, request); - request.callback = nil; - end - destroy_request(request); - end - - if not data then - error_cb(err); - return; - end - - local function success_cb(r) - if request.callback then - request.callback(r.body, r.code, r, request); - request.callback = nil; - end - destroy_request(request); - end - local function options_cb() - return request; - end - request.parser = httpstream_new(success_cb, error_cb, "client", options_cb); - end - request.parser:feed(data); +function listener.ondetach(conn) + requests[conn] = nil; end -local function handleerr(err) log("error", "Traceback[http]: %s", traceback(tostring(err), 2)); end -local function log_if_failed(id, ret, ...) - if not ret then - log("error", "Request '%s': error in callback: %s", id, tostring((...))); - end - return ...; +function listener.onfail(req, reason) + req.http.events.fire_event("request-connection-error", { http = req.http, request = req, url = req.url, err = reason }); + req.callback(reason or "connection failed", 0, req); end local function request(self, u, ex, callback) local req = url.parse(u); req.url = u; + req.http = self; if not (req and req.host) then callback("invalid-url", 0, req); @@ -166,7 +203,7 @@ local function request(self, u, ex, callback) if ret then return ret; end - req, u, ex, callback = event.request, event.url, event.options, event.callback; + req, u, ex, req.callback = event.request, event.url, event.options, event.callback; end local method, headers, body; @@ -204,6 +241,7 @@ local function request(self, u, ex, callback) end end req.insecure = ex.insecure; + req.suppress_errors = ex.suppress_errors; end log("debug", "Making %s %s request '%s' to %s", req.scheme:upper(), method or "GET", req.id, (ex and ex.suppress_url and host_header) or u); @@ -222,29 +260,8 @@ local function request(self, u, ex, callback) sslctx = ex and ex.sslctx or self.options and self.options.sslctx; end - local handler, conn = server.addclient(host, port_number, listener, "*a", sslctx) - if not handler then - self.events.fire_event("request-connection-error", { http = self, request = req, url = u, err = conn }); - callback(conn, 0, req); - return nil, conn; - end - req.handler, req.conn = handler, conn - req.write = function (...) return req.handler:write(...); end - - req.callback = function (content, code, response, request) - do - local event = { http = self, url = u, request = req, response = response, content = content, code = code, callback = callback }; - self.events.fire_event("response", event); - content, code, response = event.content, event.code, event.response; - end - - log("debug", "Request '%s': Calling callback, status %s", req.id, code or "---"); - return log_if_failed(req.id, xpcall(function () return callback(content, code, response, request) end, handleerr)); - end - req.reader = request_reader; - req.state = "status"; - - requests[req.handler] = req; + local http_service = basic_resolver.new(host, port_number); + connect(http_service, listener, { sslctx = sslctx }, req); self.events.fire_event("request", { http = self, request = req, url = u }); return req; @@ -264,6 +281,7 @@ end local default_http = new({ sslctx = { mode = "client", protocol = "sslv23", options = { "no_sslv2", "no_sslv3" } }; + suppress_errors = true; }); return { diff --git a/net/httpserver.lua b/net/httpserver.lua index 6e2e31b9..6b14313b 100644 --- a/net/httpserver.lua +++ b/net/httpserver.lua @@ -3,9 +3,10 @@ local log = require "util.logger".init("net.httpserver"); local traceback = debug.traceback; local _ENV = nil; +-- luacheck: std none -function fail() - log("error", "Attempt to use legacy HTTP API. For more info see http://prosody.im/doc/developers/legacy_http"); +local function fail() + log("error", "Attempt to use legacy HTTP API. For more info see https://prosody.im/doc/developers/legacy_http"); log("error", "Legacy HTTP API usage, %s", traceback("", 2)); end diff --git a/net/resolvers/basic.lua b/net/resolvers/basic.lua new file mode 100644 index 00000000..c2fd9260 --- /dev/null +++ b/net/resolvers/basic.lua @@ -0,0 +1,71 @@ +local adns = require "net.adns"; +local inet_pton = require "util.net".pton; + +local methods = {}; +local resolver_mt = { __index = methods }; + +-- Find the next target to connect to, and +-- pass it to cb() +function methods:next(cb) + if self.targets then + if #self.targets == 0 then + cb(nil); + return; + end + local next_target = table.remove(self.targets, 1); + cb(unpack(next_target, 1, 4)); + return; + end + + local targets = {}; + local n = 2; + local function ready() + n = n - 1; + if n > 0 then return; end + self.targets = targets; + self:next(cb); + end + + local is_ip = inet_pton(self.hostname); + if is_ip then + if #is_ip == 16 then + cb(self.conn_type.."6", self.hostname, self.port, self.extra); + elseif #is_ip == 4 then + cb(self.conn_type, self.hostname, self.port, self.extra); + end + return; + end + + -- Resolve DNS to target list + local dns_resolver = adns.resolver(); + dns_resolver:lookup(function (answer) + if answer then + for _, record in ipairs(answer) do + table.insert(targets, { self.conn_type, record.a, self.port, self.extra }); + end + end + ready(); + end, self.hostname, "A", "IN"); + + dns_resolver:lookup(function (answer) + if answer then + for _, record in ipairs(answer) do + table.insert(targets, { self.conn_type.."6", record.aaaa, self.port, self.extra }); + end + end + ready(); + end, self.hostname, "AAAA", "IN"); +end + +local function new(hostname, port, conn_type, extra) + return setmetatable({ + hostname = hostname; + port = port; + conn_type = conn_type or "tcp"; + extra = extra; + }, resolver_mt); +end + +return { + new = new; +}; diff --git a/net/resolvers/manual.lua b/net/resolvers/manual.lua new file mode 100644 index 00000000..c0d4e5d5 --- /dev/null +++ b/net/resolvers/manual.lua @@ -0,0 +1,25 @@ +local methods = {}; +local resolver_mt = { __index = methods }; + +-- Find the next target to connect to, and +-- pass it to cb() +function methods:next(cb) + if #self.targets == 0 then + cb(nil); + return; + end + local next_target = table.remove(self.targets, 1); + cb(unpack(next_target, 1, 4)); +end + +local function new(targets, conn_type, extra) + return setmetatable({ + conn_type = conn_type; + extra = extra; + targets = targets or {}; + }, resolver_mt); +end + +return { + new = new; +}; diff --git a/net/server.lua b/net/server.lua index 41e180fa..d8f24847 100644 --- a/net/server.lua +++ b/net/server.lua @@ -6,25 +6,76 @@ -- COPYING file in the source package for more information. -- -local use_luaevent = prosody and require "core.configmanager".get("*", "use_libevent"); +local log = require "util.logger".init("net.server"); +local server_type = prosody and require "core.configmanager".get("*", "network_backend") or "select"; +if prosody and require "core.configmanager".get("*", "use_libevent") then + server_type = "event"; +end -if use_luaevent then - use_luaevent = pcall(require, "luaevent.core"); - if not use_luaevent then +if server_type == "event" then + if not pcall(require, "luaevent.core") then log("error", "libevent not found, falling back to select()"); + server_type = "select" end end local server; - -if use_luaevent then +local set_config; +if server_type == "event" then server = require "net.server_event"; - -- Overwrite signal.signal() because we need to ask libevent to - -- handle them instead - local ok, signal = pcall(require, "util.signal"); - if ok and signal then - local _signal_signal = signal.signal; + local defaults = {}; + for k,v in pairs(server.cfg) do + defaults[k] = v; + end + function set_config(settings) + local event_settings = { + ACCEPT_DELAY = settings.accept_retry_interval; + ACCEPT_QUEUE = settings.tcp_backlog; + CLEAR_DELAY = settings.event_clear_interval; + CONNECT_TIMEOUT = settings.connect_timeout; + DEBUG = settings.debug; + HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; + MAX_CONNECTIONS = settings.max_connections; + MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; + MAX_READ_LENGTH = settings.max_receive_buffer_size; + MAX_SEND_LENGTH = settings.max_send_buffer_size; + READ_TIMEOUT = settings.read_timeout; + WRITE_TIMEOUT = settings.send_timeout; + }; + + for k,default in pairs(defaults) do + server.cfg[k] = event_settings[k] or default; + end + end +elseif server_type == "select" then + server = require "net.server_select"; + + local defaults = {}; + for k,v in pairs(server.getsettings()) do + defaults[k] = v; + end + function set_config(settings) + local select_settings = {}; + for k,default in pairs(defaults) do + select_settings[k] = settings[k] or default; + end + server.changesettings(select_settings); + end +else + server = require("net.server_"..server_type); + set_config = server.set_config; + if not server.get_backend then + function server.get_backend() + return server_type; + end + end +end + +-- If server.hook_signal exists, replace signal.signal() +local has_signal, signal = pcall(require, "util.signal"); +if has_signal then + if server.hook_signal then function signal.signal(signal_id, handler) if type(signal_id) == "string" then signal_id = signal[signal_id:upper()]; @@ -34,46 +85,22 @@ if use_luaevent then end return server.hook_signal(signal_id, handler); end + else + server.hook_signal = signal.signal; end else - use_luaevent = false; - server = require "net.server_select"; + if not server.hook_signal then + server.hook_signal = function() + return false, "signal hooking not supported" + end + end end -if prosody then +if prosody and set_config then local config_get = require "core.configmanager".get; - local defaults = {}; - for k,v in pairs(server.cfg or server.getsettings()) do - defaults[k] = v; - end local function load_config() local settings = config_get("*", "network_settings") or {}; - if use_luaevent then - local event_settings = { - ACCEPT_DELAY = settings.accept_retry_interval; - ACCEPT_QUEUE = settings.tcp_backlog; - CLEAR_DELAY = settings.event_clear_interval; - CONNECT_TIMEOUT = settings.connect_timeout; - DEBUG = settings.debug; - HANDSHAKE_TIMEOUT = settings.ssl_handshake_timeout; - MAX_CONNECTIONS = settings.max_connections; - MAX_HANDSHAKE_ATTEMPTS = settings.max_ssl_handshake_roundtrips; - MAX_READ_LENGTH = settings.max_receive_buffer_size; - MAX_SEND_LENGTH = settings.max_send_buffer_size; - READ_TIMEOUT = settings.read_timeout; - WRITE_TIMEOUT = settings.send_timeout; - }; - - for k,default in pairs(defaults) do - server.cfg[k] = event_settings[k] or default; - end - else - local select_settings = {}; - for k,default in pairs(defaults) do - select_settings[k] = settings[k] or default; - end - server.changesettings(select_settings); - end + return set_config(settings); end load_config(); prosody.events.add_handler("config-reloaded", load_config); diff --git a/net/server_epoll.lua b/net/server_epoll.lua new file mode 100644 index 00000000..0881f797 --- /dev/null +++ b/net/server_epoll.lua @@ -0,0 +1,718 @@ +-- Prosody IM +-- Copyright (C) 2016 Kim Alvefur +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- server_epoll +-- Server backend based on https://luarocks.org/modules/zash/lua-epoll + +local t_sort = table.sort; +local t_insert = table.insert; +local t_remove = table.remove; +local t_concat = table.concat; +local setmetatable = setmetatable; +local tostring = tostring; +local pcall = pcall; +local type = type; +local next = next; +local pairs = pairs; +local log = require "util.logger".init("server_epoll"); +local epoll = require "epoll"; +local socket = require "socket"; +local luasec = require "ssl"; +local gettime = require "util.time".now; +local createtable = require "util.table".create; +local _SOCKETINVALID = socket._SOCKETINVALID or -1; + +assert(socket.tcp6 and socket.tcp4, "Incompatible LuaSocket version"); + +local _ENV = nil; +-- luacheck: std none + +local default_config = { __index = { + read_timeout = 900; + write_timeout = 7; + tcp_backlog = 128; + accept_retry_interval = 10; + read_retry_delay = 1e-06; + connect_timeout = 20; + handshake_timeout = 60; + max_wait = 86400; + min_wait = 1e-06; +}}; +local cfg = default_config.__index; + +local fds = createtable(10, 0); -- FD -> conn + +-- Timer and scheduling -- + +local timers = {}; + +local function noop() end +local function closetimer(t) + t[1] = 0; + t[2] = noop; +end + +-- Set to true when timers have changed +local resort_timers = false; + +-- Add absolute timer +local function at(time, f) + local timer = { time, f, close = closetimer }; + t_insert(timers, timer); + resort_timers = true; + return timer; +end + +-- Add relative timer +local function addtimer(timeout, f) + return at(gettime() + timeout, f); +end + +-- Run callbacks of expired timers +-- Return time until next timeout +local function runtimers(next_delay, min_wait) + -- Any timers at all? + if not timers[1] then + return next_delay; + end + + if resort_timers then + -- Sort earliest timers to the end + t_sort(timers, function (a, b) return a[1] > b[1]; end); + resort_timers = false; + end + + -- Iterate from the end and remove completed timers + for i = #timers, 1, -1 do + local timer = timers[i]; + local t, f = timer[1], timer[2]; + -- Get time for every iteration to increase accuracy + local now = gettime(); + if t > now then + -- This timer should not fire yet + local diff = t - now; + if diff < next_delay then + next_delay = diff; + end + break; + end + local new_timeout = f(now); + if new_timeout then + -- Schedule for 'delay' from the time actually scheduled, + -- not from now, in order to prevent timer drift. + timer[1] = t + new_timeout; + resort_timers = true; + else + t_remove(timers, i); + end + end + + if resort_timers or next_delay < min_wait then + -- Timers may be added from within a timer callback. + -- Those would not be considered for next_delay, + -- and we might sleep for too long, so instead + -- we return a shorter timeout so we can + -- properly sort all new timers. + next_delay = min_wait; + end + + return next_delay; +end + +-- Socket handler interface + +local interface = {}; +local interface_mt = { __index = interface }; + +function interface_mt:__tostring() + if self.sockname and self.peername then + return ("FD %d (%s, %d, %s, %d)"):format(self:getfd(), self.peername, self.peerport, self.sockname, self.sockport); + elseif self.sockname or self.peername then + return ("FD %d (%s, %d)"):format(self:getfd(), self.sockname or self.peername, self.sockport or self.peerport); + end + return ("%s FD %d"):format(tostring(self.conn), self:getfd()); +end + +-- Replace the listener and tell the old one +function interface:setlistener(listeners) + self:on("detach"); + self.listeners = listeners; +end + +-- Call a listener callback +function interface:on(what, ...) + if not self.listeners then + log("error", "%s has no listeners", self); + return; + end + local listener = self.listeners["on"..what]; + if not listener then + -- log("debug", "Missing listener 'on%s'", what); -- uncomment for development and debugging + return; + end + local ok, err = pcall(listener, self, ...); + if not ok then + log("error", "Error calling on%s: %s", what, err); + end + return err; +end + +-- Return the file descriptor number +function interface:getfd() + if self.conn then + return self.conn:getfd(); + end + return _SOCKETINVALID; +end + +function interface:server() + return self._server or self; +end + +-- Get IP address +function interface:ip() + return self.peername or self.sockname; +end + +-- Get a port number, doesn't matter which +function interface:port() + return self.sockport or self.peerport; +end + +-- Get local port number +function interface:clientport() + return self.sockport; +end + +-- Get remote port +function interface:serverport() + if self.sockport then + return self.sockport; + elseif self._server then + self._server:port(); + end +end + +-- Return underlying socket +function interface:socket() + return self.conn; +end + +function interface:set_mode(new_mode) + self._pattern = new_mode; +end + +function interface:setoption(k, v) + -- LuaSec doesn't expose setoption :( + if self.conn.setoption then + self.conn:setoption(k, v); + end +end + +-- Timeout for detecting dead or idle sockets +function interface:setreadtimeout(t) + if t == false then + if self._readtimeout then + self._readtimeout:close(); + self._readtimeout = nil; + end + return + end + t = t or cfg.read_timeout; + if self._readtimeout then + self._readtimeout[1] = gettime() + t; + resort_timers = true; + else + self._readtimeout = addtimer(t, function () + if self:on("readtimeout") then + return cfg.read_timeout; + else + self:on("disconnect", "read timeout"); + self:destroy(); + end + end); + end +end + +-- Timeout for detecting dead sockets +function interface:setwritetimeout(t) + if t == false then + if self._writetimeout then + self._writetimeout:close(); + self._writetimeout = nil; + end + return + end + t = t or cfg.write_timeout; + if self._writetimeout then + self._writetimeout[1] = gettime() + t; + resort_timers = true; + else + self._writetimeout = addtimer(t, function () + self:on("disconnect", "write timeout"); + self:destroy(); + end); + end +end + +-- lua-epoll flag for currently requested poll state +function interface:flags() + if self._wantread then + if self._wantwrite then + return "rw"; + end + return "r"; + elseif self._wantwrite then + return "w"; + end +end + +-- Add or remove sockets or modify epoll flags +function interface:setflags(r, w) + if r ~= nil then self._wantread = r; end + if w ~= nil then self._wantwrite = w; end + local flags = self:flags(); + local currentflags = self._flags; + if flags == currentflags then + return true; + end + local fd = self:getfd(); + if fd < 0 then + self._wantread, self._wantwrite = nil, nil; + return nil, "invalid fd"; + end + local op = "mod"; + if not flags then + op = "del"; + elseif not currentflags then + op = "add"; + end + local ok, err = epoll.ctl(op, fd, flags); +-- log("debug", "epoll_ctl(%q, %d, %q) -> %s" .. (err and ", %q" or ""), +-- op, fd, flags or "", tostring(ok), err); + if not ok then return ok, err end + if op == "add" then + fds[fd] = self; + elseif op == "del" then + fds[fd] = nil; + end + self._flags = flags; + return true; +end + +-- Called when socket is readable +function interface:onreadable() + local data, err, partial = self.conn:receive(self._pattern); + if data then + self:onconnect(); + self:on("incoming", data); + else + if partial and partial ~= "" then + self:onconnect(); + self:on("incoming", partial, err); + end + if err == "wantread" then + self:setflags(true, nil); + elseif err == "wantwrite" then + self:setflags(nil, true); + elseif err ~= "timeout" then + self:on("disconnect", err); + self:destroy() + return; + end + end + if not self.conn then return; end + if self.conn:dirty() then + self:setreadtimeout(false); + self:pausefor(cfg.read_retry_delay); + else + self:setreadtimeout(); + end +end + +-- Called when socket is writable +function interface:onwritable() + self:onconnect(); + if not self.conn then return; end -- could have been closed in onconnect + local buffer = self.writebuffer; + local data = t_concat(buffer); + local ok, err, partial = self.conn:send(data); + if ok then + self:setflags(nil, false); + for i = #buffer, 1, -1 do + buffer[i] = nil; + end + self:setwritetimeout(false); + self:ondrain(); -- Be aware of writes in ondrain + return; + elseif partial then + buffer[1] = data:sub(partial+1); + for i = #buffer, 2, -1 do + buffer[i] = nil; + end + self:setwritetimeout(); + end + if err == "wantwrite" or err == "timeout" then + self:setflags(nil, true); + elseif err == "wantread" then + self:setflags(true, nil); + elseif err ~= "timeout" then + self:on("disconnect", err); + self:destroy(); + end +end + +-- The write buffer has been successfully emptied +function interface:ondrain() + return self:on("drain"); +end + +-- Add data to write buffer and set flag for wanting to write +function interface:write(data) + local buffer = self.writebuffer; + if buffer then + t_insert(buffer, data); + else + self.writebuffer = { data }; + end + self:setwritetimeout(); + self:setflags(nil, true); + return #data; +end +interface.send = interface.write; + +-- Close, possibly after writing is done +function interface:close() + if self.writebuffer and self.writebuffer[1] then + self:setflags(false, true); -- Flush final buffer contents + self.write, self.send = noop, noop; -- No more writing + log("debug", "Close %s after writing", tostring(self)); + self.ondrain = interface.close; + else + log("debug", "Close %s now", tostring(self)); + self.write, self.send = noop, noop; + self.close = noop; + self:on("disconnect"); + self:destroy(); + end +end + +function interface:destroy() + self:setflags(false, false); + self:setwritetimeout(false); + self:setreadtimeout(false); + self.onreadable = noop; + self.onwritable = noop; + self.destroy = noop; + self.close = noop; + self.on = noop; + self.conn:close(); + self.conn = nil; +end + +function interface:ssl() + return self._tls; +end + +function interface:starttls(ctx) + if ctx then self.tls = ctx; end + if self.writebuffer and self.writebuffer[1] then + log("debug", "Start TLS on %s after write", tostring(self)); + self.ondrain = interface.starttls; + self.starttls = false; + self:setflags(nil, true); -- make sure wantwrite is set + else + log("debug", "Start TLS on %s now", tostring(self)); + self:setflags(false, false); + local conn, err = luasec.wrap(self.conn, ctx or self.tls); + if not conn then + self:on("disconnect", err); + self:destroy(); + return conn, err; + end + conn:settimeout(0); + self.conn = conn; + self.ondrain = nil; + self.onwritable = interface.tlshandskake; + self.onreadable = interface.tlshandskake; + self:setflags(true, true); + self:setwritetimeout(cfg.handshake_timeout); + end +end + +function interface:tlshandskake() + self:setwritetimeout(false); + self:setreadtimeout(false); + local ok, err = self.conn:dohandshake(); + if ok then + log("debug", "TLS handshake on %s complete", tostring(self)); + self.onwritable = nil; + self.onreadable = nil; + self._tls = true; + self:on("status", "ssl-handshake-complete"); + self:init(); + elseif err == "wantread" then + log("debug", "TLS handshake on %s to wait until readable", tostring(self)); + self:setflags(true, false); + self:setreadtimeout(cfg.handshake_timeout); + elseif err == "wantwrite" then + log("debug", "TLS handshake on %s to wait until writable", tostring(self)); + self:setflags(false, true); + self:setwritetimeout(cfg.handshake_timeout); + else + log("debug", "TLS handshake error on %s: %s", tostring(self), err); + self:on("disconnect", err); + self:destroy(); + end +end + +local function wrapsocket(client, server, pattern, listeners, tls) -- luasocket object -> interface object + client:settimeout(0); + local conn = setmetatable({ + conn = client; + _server = server; + created = gettime(); + listeners = listeners; + _pattern = pattern or (server and server._pattern); + writebuffer = {}; + tls = tls; + }, interface_mt); + + if client.getpeername then + conn.peername, conn.peerport = client:getpeername(); + end + if client.getsockname then + conn.sockname, conn.sockport = client:getsockname(); + end + return conn; +end + +-- A server interface has new incoming connections waiting +-- This replaces the onreadable callback +function interface:onacceptable() + local conn, err = self.conn:accept(); + if not conn then + log("debug", "Error accepting new client: %s, server will be paused for %ds", err, cfg.accept_retry_interval); + self:pausefor(cfg.accept_retry_interval); + return; + end + local client = wrapsocket(conn, self, nil, self.listeners, self.tls); + log("debug", "New connection %s", tostring(client)); + client:init(); +end + +-- Initialization +function interface:init() + if self.tls and not self._tls then + return self:starttls(); + else + self:setwritetimeout(); + return self:setflags(true, true); + end +end + +function interface:pause() + return self:setflags(false); +end + +function interface:resume() + return self:setflags(true); +end + +-- Pause connection for some time +function interface:pausefor(t) + if self._pausefor then + self._pausefor:close(); + end + if t == false then return; end + self:setflags(false); + self._pausefor = addtimer(t, function () + self._pausefor = nil; + if self.conn and self.conn:dirty() then + self:onreadable(); + end + self:setflags(true); + end); +end + +-- Connected! +function interface:onconnect() + if self.conn and not self.peername and self.conn.getpeername then + self.peername, self.peerport = self.conn:getpeername(); + end + self.onconnect = noop; + self:on("connect"); +end + +local function addserver(addr, port, listeners, pattern, tls) + local conn, err = socket.bind(addr, port, cfg.tcp_backlog); + if not conn then return conn, err; end + conn:settimeout(0); + local server = setmetatable({ + conn = conn; + created = gettime(); + listeners = listeners; + _pattern = pattern; + onreadable = interface.onacceptable; + tls = tls; + sockname = addr; + sockport = port; + }, interface_mt); + server:setflags(true, false); + return server; +end + +-- COMPAT +local function wrapclient(conn, addr, port, listeners, pattern, tls) + local client = wrapsocket(conn, nil, pattern, listeners, tls); + if not client.peername then + client.peername, client.peerport = addr, port; + end + client:init(); + return client; +end + +-- New outgoing TCP connection +local function addclient(addr, port, listeners, pattern, tls) + local conn, err = socket.tcp(); + if not conn then return conn, err; end + conn:settimeout(0); + conn:connect(addr, port); + local client = wrapsocket(conn, nil, pattern, listeners, tls) + client:init(); + return client, conn; +end + +local function watchfd(fd, onreadable, onwriteable) + local conn = setmetatable({ + conn = fd; + onreadable = onreadable; + onwriteable = onwriteable; + close = function (self) + self:setflags(false, false); + end + }, interface_mt); + if type(fd) == "number" then + conn.getfd = function () + return fd; + end; + -- Otherwise it'll need to be something LuaSocket-compatible + end + conn:setflags(onreadable, onwriteable); + return conn; +end; + +-- Dump all data from one connection into another +local function link(from, to) + from.listeners = setmetatable({ + onincoming = function (_, data) + from:pause(); + to:write(data); + end, + }, {__index=from.listeners}); + to.listeners = setmetatable({ + ondrain = function () + from:resume(); + end, + }, {__index=to.listeners}); + from:setflags(true, nil); + to:setflags(nil, true); +end + +-- XXX What uses this? +-- net.adns +function interface:set_send(new_send) + self.send = new_send; +end + +-- Close all connections and servers +local function closeall() + for fd, conn in pairs(fds) do -- luacheck: ignore 213/fd + conn:close(); + end +end + +local quitting = nil; + +-- Signal main loop about shutdown via above upvalue +local function setquitting(quit) + if quit then + quitting = "quitting"; + closeall(); + else + quitting = nil; + end +end + +-- Main loop +local function loop(once) + repeat + local t = runtimers(cfg.max_wait, cfg.min_wait); + local fd, r, w = epoll.wait(t); + if fd then + local conn = fds[fd]; + if conn then + if r then + conn:onreadable(); + end + if w then + conn:onwritable(); + end + else + log("debug", "Removing unknown fd %d", fd); + epoll.ctl("del", fd); + end + elseif r ~= "timeout" then + log("debug", "epoll_wait error: %s", tostring(r)); + end + until once or (quitting and next(fds) == nil); + return quitting; +end + +return { + get_backend = function () return "epoll"; end; + addserver = addserver; + addclient = addclient; + add_task = addtimer; + at = at; + loop = loop; + closeall = closeall; + setquitting = setquitting; + wrapclient = wrapclient; + watchfd = watchfd; + link = link; + set_config = function (newconfig) + cfg = setmetatable(newconfig, default_config); + end; + + -- libevent emulation + event = { EV_READ = "r", EV_WRITE = "w", EV_READWRITE = "rw", EV_LEAVE = -1 }; + addevent = function (fd, mode, callback) + local function onevent(self) + local ret = self:callback(); + if ret == -1 then + self:setflags(false, false); + elseif ret then + self:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); + end + end + + local conn = setmetatable({ + getfd = function () return fd; end; + callback = callback; + onreadable = onevent; + onwritable = onevent; + close = function (self) + self:setflags(false, false); + fds[fd] = nil; + end; + }, interface_mt); + local ok, err = conn:setflags(mode == "r" or mode == "rw", mode == "w" or mode == "rw"); + if not ok then return ok, err; end + return conn; + end; +}; diff --git a/net/server_event.lua b/net/server_event.lua index 3a907349..3e949092 100644 --- a/net/server_event.lua +++ b/net/server_event.lua @@ -5,9 +5,9 @@ notes: -- when using luaevent, never register 2 or more EV_READ at one socket, same for EV_WRITE - -- you cant even register a new EV_READ/EV_WRITE callback inside another one + -- you can't even register a new EV_READ/EV_WRITE callback inside another one -- to do some of the above, use timeout events or something what will called from outside - -- dont let garbagecollect eventcallbacks, as long they are running + -- don't let garbagecollect eventcallbacks, as long they are running -- when using luasec, there are 4 cases of timeout errors: wantread or wantwrite during reading or writing --]] @@ -106,6 +106,12 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient self:_close() debug( "new connection failed. id:", self.id, "error:", self.fatalerror ) else + if EV_READWRITE == event then + if self.readcallback(event) == -1 then + -- Fatal error occurred + return -1; + end + end if plainssl and has_luasec then -- start ssl session self:starttls(self._sslctx, true) else -- normal connection @@ -116,7 +122,7 @@ function interface_mt:_start_connection(plainssl) -- called from wrapclient self.eventconnect = nil return -1 end - self.eventconnect = addevent( base, self.conn, EV_WRITE, callback, cfg.CONNECT_TIMEOUT ) + self.eventconnect = addevent( base, self.conn, EV_READWRITE, callback, cfg.CONNECT_TIMEOUT ) return true end function interface_mt:_start_session(call_onconnect) -- new session, for example after startssl @@ -151,7 +157,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed self.fatalerror = err self.conn = nil -- cannot be used anymore if call_onconnect then - self.ondisconnect = nil -- dont call this when client isnt really connected + self.ondisconnect = nil -- don't call this when client isn't really connected end self:_close() debug( "fatal error while ssl wrapping:", err ) @@ -194,7 +200,7 @@ function interface_mt:_start_ssl(call_onconnect) -- old socket will be destroyed end if self.fatalerror then if call_onconnect then - self.ondisconnect = nil -- dont call this when client isnt really connected + self.ondisconnect = nil -- don't call this when client isn't really connected end self:_close() debug( "handshake failed because:", self.fatalerror ) @@ -223,7 +229,8 @@ function interface_mt:_destroy() -- close this interface + events and call last _ = self.eventsession and self.eventsession:close( ) _ = self.eventwritetimeout and self.eventwritetimeout:close( ) _ = self.eventreadtimeout and self.eventreadtimeout:close( ) - _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) -- call ondisconnect listener (wont be the case if handshake failed on connect) + -- call ondisconnect listener (won't be the case if handshake failed on connect) + _ = self.ondisconnect and self:ondisconnect( self.fatalerror ~= "client to close" and self.fatalerror) _ = self.conn and self.conn:close( ) -- close connection _ = self._server and self._server:counter(-1); self.eventread, self.eventwrite = nil, nil @@ -510,7 +517,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx interface.writebuffer = { t_concat(interface.writebuffer) } local succ, err, byte = interface.conn:send( interface.writebuffer[1], 1, interface.writebufferlen ) --vdebug( "write data:", interface.writebuffer, "error:", err, "part:", byte ) - if succ then -- writing succesful + if succ then -- writing successful interface.writebuffer[1] = nil interface.writebufferlen = 0 interface:ondrain(); @@ -539,7 +546,7 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx return -1; end interface.eventwritetimeout = addevent( base, nil, EV_TIMEOUT, callback, cfg.WRITE_TIMEOUT ) -- reg a new timeout event - debug( "wantread during write attempt, reg it in readcallback but dont know what really happens next..." ) + debug( "wantread during write attempt, reg it in readcallback but don't know what really happens next..." ) -- hopefully this works with luasec; its simply not possible to use 2 different write events on a socket in luaevent return -1 end @@ -595,8 +602,8 @@ local function handleclient( client, ip, port, server, pattern, listener, sslctx end interface.eventreadtimeout = addevent( base, nil, EV_TIMEOUT, function( ) interface:_close() end, cfg.READ_TIMEOUT) - debug( "wantwrite during read attempt, reg it in writecallback but dont know what really happens next..." ) - -- to be honest i dont know what happens next, if it is allowed to first read, the write etc... + debug( "wantwrite during read attempt, reg it in writecallback but don't know what really happens next..." ) + -- to be honest i don't know what happens next, if it is allowed to first read, the write etc... else -- connection was closed or fatal error interface.fatalerror = err debug( "connection failed in read event:", interface.fatalerror ) @@ -767,13 +774,15 @@ end local function setquitting(yes) if yes then -- Quit now - closeallservers(); + if yes ~= "once" then + closeallservers(); + end base:loopexit(); end end local function get_backend() - return base:method(); + return "libevent " .. base:method(); end -- We need to hold onto the events to stop them @@ -811,6 +820,48 @@ local function link(sender, receiver, buffersize) sender:set_mode("*a"); end +local function add_task(delay, callback) + local event_handle; + event_handle = base:addevent(nil, 0, function () + local ret = callback(socket_gettime()); + if ret then + return 0, ret; + elseif event_handle then + return -1; + end + end + , delay); + return event_handle; +end + +local function watchfd(fd, onreadable, onwriteable) + local handle = {}; + function handle:setflags(r,w) + if r ~= nil then + if r and not self.wantread then + self.wantread = base:addevent(fd, EV_READ, function () + onreadable(self); + end); + elseif not r and self.wantread then + self.wantread:close(); + self.wantread = nil; + end + end + if w ~= nil then + if w and not self.wantwrite then + self.wantwrite = base:addevent(fd, EV_WRITE, function () + onwriteable(self); + end); + elseif not r and self.wantread then + self.wantwrite:close(); + self.wantwrite = nil; + end + end + end + handle:setflags(onreadable, onwriteable); + return handle; +end + return { cfg = cfg, base = base, @@ -826,6 +877,8 @@ return { closeall = closeallservers, get_backend = get_backend, hook_signal = hook_signal, + add_task = add_task, + watchfd = watchfd, __NAME = SCRIPT_NAME, __DATE = LAST_MODIFIED, diff --git a/net/server_select.lua b/net/server_select.lua index 12aef9d8..3b83bb6d 100644 --- a/net/server_select.lua +++ b/net/server_select.lua @@ -40,6 +40,7 @@ local coroutine = use "coroutine" local math_min = math.min local math_huge = math.huge local table_concat = table.concat +local table_insert = table.insert local string_sub = string.sub local coroutine_wrap = coroutine.wrap local coroutine_yield = coroutine.yield @@ -55,7 +56,6 @@ local getaddrinfo = luasocket.dns.getaddrinfo local ssl_wrap = ( has_luasec and luasec.wrap ) local socket_bind = luasocket.bind -local socket_sleep = luasocket.sleep local socket_select = luasocket.select --// functions //-- @@ -100,7 +100,6 @@ local _sendtraffic local _readtraffic local _selecttimeout -local _sleeptime local _tcpbacklog local _accepretry @@ -114,8 +113,6 @@ local _checkinterval local _sendtimeout local _readtimeout -local _timer - local _maxselectlen local _maxfd @@ -135,13 +132,12 @@ _fullservers = { } -- servers in a paused state while there are too many clients _readlistlen = 0 -- length of readlist _sendlistlen = 0 -- length of sendlist -_timerlistlen = 0 -- lenght of timerlist +_timerlistlen = 0 -- length of timerlist _sendtraffic = 0 -- some stats _readtraffic = 0 _selecttimeout = 1 -- timeout of socket.select -_sleeptime = 0 -- time to wait at the end of every loop _tcpbacklog = 128 -- some kind of hint to the OS _accepretry = 10 -- seconds to wait until the next attempt of a full server to accept @@ -301,7 +297,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport local bufferqueuelen = 0 -- end of buffer array local toclose - local fatalerror local needtls local bufferlen = 0 @@ -425,7 +420,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport bufferlen = bufferlen + #data if bufferlen > maxsendlen then _closelist[ handler ] = "send buffer exceeded" -- cannot close the client at the moment, have to wait to the end of the cycle - handler.write = idfalse -- dont write anymore + handler.write = idfalse -- don't write anymore return false elseif socket and not _sendlist[ socket ] then _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) @@ -517,7 +512,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return dispatch( handler, buffer, err ) else -- connections was closed or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " read error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -537,7 +531,7 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport else succ, err, count = false, "unexpected close", 0; end - if succ then -- sending succesful + if succ then -- sending successful bufferqueuelen = 0 bufferlen = 0 _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) -- delete socket from writelist @@ -557,7 +551,6 @@ wrapconnection = function( server, listeners, socket, ip, serverport, clientport return true else -- connection was closed during sending or fatal error out_put( "server.lua: client ", tostring(ip), ":", tostring(clientport), " write error: ", tostring(err) ) - fatalerror = true _ = handler and handler:force_close( err ) return false end @@ -806,7 +799,6 @@ end getsettings = function( ) return { select_timeout = _selecttimeout; - select_sleep_time = _sleeptime; tcp_backlog = _tcpbacklog; max_send_buffer_size = _maxsendlen; max_receive_buffer_size = _maxreadlen; @@ -825,7 +817,6 @@ changesettings = function( new ) return nil, "invalid settings table" end _selecttimeout = tonumber( new.select_timeout ) or _selecttimeout - _sleeptime = tonumber( new.select_sleep_time ) or _sleeptime _maxsendlen = tonumber( new.max_send_buffer_size ) or _maxsendlen _maxreadlen = tonumber( new.max_receive_buffer_size ) or _maxreadlen _checkinterval = tonumber( new.select_idle_check_interval ) or _checkinterval @@ -848,6 +839,49 @@ addtimer = function( listener ) return true end +local add_task do + local data = {}; + local new_data = {}; + + function add_task(delay, callback) + local current_time = luasocket_gettime(); + delay = delay + current_time; + if delay >= current_time then + table_insert(new_data, {delay, callback}); + else + local r = callback(current_time); + if r and type(r) == "number" then + return add_task(r, callback); + end + end + end + + addtimer(function(current_time) + if #new_data > 0 then + for _, d in pairs(new_data) do + table_insert(data, d); + end + new_data = {}; + end + + local next_time = math_huge; + for i, d in pairs(data) do + local t, callback = d[1], d[2]; + if t <= current_time then + data[i] = nil; + local r = callback(current_time); + if type(r) == "number" then + add_task(r, callback); + next_time = math_min(next_time, r); + end + else + next_time = math_min(next_time, t - current_time); + end + end + return next_time; + end); +end + stats = function( ) return _readtraffic, _sendtraffic, _readlistlen, _sendlistlen, _timerlistlen end @@ -855,31 +889,38 @@ end local quitting; local function setquitting(quit) - quitting = not not quit; + quitting = quit; end loop = function(once) -- this is the main loop of the program if quitting then return "quitting"; end if once then quitting = "once"; end - local next_timer_time = math_huge; + _currenttime = luasocket_gettime( ) repeat + -- Fire timers + local next_timer_time = math_huge; + for i = 1, _timerlistlen do + local t = _timerlist[ i ]( _currenttime ) -- fire timers + if t then next_timer_time = math_min(next_timer_time, t); end + end + local read, write, err = socket_select( _readlist, _sendlist, math_min(_selecttimeout, next_timer_time) ) - for _, socket in ipairs( write ) do -- send data waiting in writequeues + for _, socket in ipairs( read ) do -- receive data local handler = _socketlist[ socket ] if handler then - handler.sendbuffer( ) + handler.readbuffer( ) else closesocket( socket ) - out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen + out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen end end - for _, socket in ipairs( read ) do -- receive data + for _, socket in ipairs( write ) do -- send data waiting in writequeues local handler = _socketlist[ socket ] if handler then - handler.readbuffer( ) + handler.sendbuffer( ) else closesocket( socket ) - out_put "server.lua: found no handler and closed socket (readlist)" -- this can happen + out_put "server.lua: found no handler and closed socket (writelist)" -- this should not happen end end for handler, err in pairs( _closelist ) do @@ -910,29 +951,14 @@ loop = function(once) -- this is the main loop of the program end end - -- Fire timers - if _currenttime - _timer >= math_min(next_timer_time, 1) then - next_timer_time = math_huge; - for i = 1, _timerlistlen do - local t = _timerlist[ i ]( _currenttime ) -- fire timers - if t then next_timer_time = math_min(next_timer_time, t); end - end - _timer = _currenttime - else - next_timer_time = next_timer_time - (_currenttime - _timer); - end - for server, paused_time in pairs( _fullservers ) do if _currenttime - paused_time > _accepretry then _fullservers[ server ] = nil; server.resume(); end end - - -- wait some time (0 by default) - socket_sleep( _sleeptime ) until quitting; - if once and quitting == "once" then quitting = nil; return; end + if quitting == "once" then quitting = nil; return; end closeall(); return "quitting" end @@ -952,6 +978,7 @@ local wrapclient = function( socket, ip, serverport, listeners, pattern, sslctx if not handler then return nil, err end _socketlist[ socket ] = handler if not sslctx then + _readlistlen = addsocket(_readlist, socket, _readlistlen) _sendlistlen = addsocket(_sendlist, socket, _sendlistlen) if listeners.onconnect then -- When socket is writeable, call onconnect @@ -977,16 +1004,14 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) elseif sslctx and not has_luasec then err = "luasec not found" end - if not typ then + if getaddrinfo and not typ then local addrinfo, err = getaddrinfo(address) if not addrinfo then return nil, err end if addrinfo[1] and addrinfo[1].family == "inet6" then typ = "tcp6" - else - typ = "tcp" end end - local create = luasocket[typ] + local create = luasocket[typ or "tcp"] if type( create ) ~= "function" then err = "invalid socket type" end @@ -1002,14 +1027,54 @@ local addclient = function( address, port, listeners, pattern, sslctx, typ ) end client:settimeout( 0 ) local ok, err = client:connect( address, port ) - if ok or err == "timeout" then + if ok or err == "timeout" or err == "Operation already in progress" then return wrapclient( client, address, port, listeners, pattern, sslctx ) else return nil, err end end ---// EXPERIMENTAL //-- +local closewatcher = function (handler) + local socket = handler.conn; + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + _socketlist[ socket ] = nil +end; + +local addremove = function (handler, read, send) + local socket = handler.conn + _socketlist[ socket ] = handler + if read ~= nil then + if read then + _readlistlen = addsocket( _readlist, socket, _readlistlen ) + else + _sendlistlen = removesocket( _sendlist, socket, _sendlistlen ) + end + end + if send ~= nil then + if send then + _sendlistlen = addsocket( _sendlist, socket, _sendlistlen ) + else + _readlistlen = removesocket( _readlist, socket, _readlistlen ) + end + end +end + +local watchfd = function ( fd, onreadable, onwriteable ) + local socket = fd + if type(fd) == "number" then + socket = { getfd = function () return fd; end } + end + local handler = { + conn = socket; + readbuffer = onreadable or id; + sendbuffer = onwriteable or id; + close = closewatcher; + setflags = addremove; + }; + addremove( handler, onreadable, onwriteable ) + return handler +end ----------------------------------// BEGIN //-- @@ -1017,7 +1082,6 @@ use "setmetatable" ( _socketlist, { __mode = "k" } ) use "setmetatable" ( _readtimes, { __mode = "k" } ) use "setmetatable" ( _writetimes, { __mode = "k" } ) -_timer = luasocket_gettime( ) _starttime = luasocket_gettime( ) local function setlogger(new_logger) @@ -1032,9 +1096,11 @@ end return { _addtimer = addtimer, + add_task = add_task; addclient = addclient, wrapclient = wrapclient, + watchfd = watchfd, loop = loop, link = link, diff --git a/net/websocket.lua b/net/websocket.lua index 777b894c..469c6a58 100644 --- a/net/websocket.lua +++ b/net/websocket.lua @@ -21,9 +21,9 @@ local close_timeout = 3; -- Seconds to wait after sending close frame until clos local websockets = {}; local websocket_listeners = {}; -function websocket_listeners.ondisconnect(handler, err) - local s = websockets[handler]; - websockets[handler] = nil; +function websocket_listeners.ondisconnect(conn, err) + local s = websockets[conn]; + websockets[conn] = nil; if s.close_timer then timer.stop(s.close_timer); s.close_timer = nil; @@ -33,19 +33,19 @@ function websocket_listeners.ondisconnect(handler, err) if s.onclose then s:onclose(s.close_code, s.close_message or err); end end -function websocket_listeners.ondetach(handler) - websockets[handler] = nil; +function websocket_listeners.ondetach(conn) + websockets[conn] = nil; end local function fail(s, code, reason) log("warn", "WebSocket connection failed, closing. %d %s", code, reason); s:close(code, reason); - s.handler:close(); + s.conn:close(); return false end -function websocket_listeners.onincoming(handler, buffer, err) -- luacheck: ignore 212/err - local s = websockets[handler]; +function websocket_listeners.onincoming(conn, buffer, err) -- luacheck: ignore 212/err + local s = websockets[conn]; s.readbuffer = s.readbuffer..buffer; while true do local frame, len = frames.parse(s.readbuffer); @@ -111,7 +111,7 @@ function websocket_listeners.onincoming(handler, buffer, err) -- luacheck: ignor elseif frame.opcode == 0x9 then -- Ping frame frame.opcode = 0xA; frame.MASK = true; -- RFC 6455 6.1.5: If the data is being sent by the client, the frame(s) MUST be masked - handler:write(frames.build(frame)); + conn:write(frames.build(frame)); elseif frame.opcode == 0xA then -- Pong frame log("debug", "Received unexpected pong frame: " .. tostring(frame.data)); else @@ -126,15 +126,15 @@ local websocket_methods = {}; local function close_timeout_cb(now, timerid, s) -- luacheck: ignore 212/now 212/timerid s.close_timer = nil; log("warn", "Close timeout waiting for server to close, closing manually."); - s.handler:close(); + s.conn:close(); end function websocket_methods:close(code, reason) if self.readyState < 2 then code = code or 1000; log("debug", "closing WebSocket with code %i: %s" , code , tostring(reason)); self.readyState = 2; - local handler = self.handler; - handler:write(frames.build_close(code, reason, true)); + local conn = self.conn; + conn:write(frames.build_close(code, reason, true)); -- Do not close socket straight away, wait for acknowledgement from server. self.close_timer = timer.add_task(close_timeout, close_timeout_cb, self); elseif self.readyState == 2 then @@ -144,8 +144,8 @@ function websocket_methods:close(code, reason) timer.stop(self.close_timer); self.close_timer = nil; end - local handler = self.handler; - handler:close(); + local conn = self.conn; + conn:close(); else log("debug", "tried to close a closed WebSocket, ignoring."); end @@ -168,7 +168,7 @@ function websocket_methods:send(data, opcode) data = tostring(data); }; log("debug", "WebSocket sending frame: opcode=%0x, %i bytes", frame.opcode, #frame.data); - return self.handler:write(frames.build(frame)); + return self.conn:write(frames.build(frame)); end local websocket_metatable = { @@ -216,7 +216,7 @@ local function connect(url, ex, listeners) local s = setmetatable({ readbuffer = ""; databuffer = nil; - handler = nil; + conn = nil; close_code = nil; close_message = nil; close_timer = nil; @@ -236,6 +236,7 @@ local function connect(url, ex, listeners) method = "GET"; headers = headers; sslctx = ex.sslctx; + insecure = ex.insecure; }, function(b, c, r, http_req) if c ~= 101 or r.headers["connection"]:lower() ~= "upgrade" @@ -252,16 +253,16 @@ local function connect(url, ex, listeners) s.protocol = r.headers["sec-websocket-protocol"]; -- Take possession of socket from http + local conn = http_req.conn; http_req.conn = nil; - local handler = http_req.handler; - s.handler = handler; - websockets[handler] = s; - handler:setlistener(websocket_listeners); + s.conn = conn; + websockets[conn] = s; + conn:setlistener(websocket_listeners); log("debug", "WebSocket connected successfully to %s", url); s.readyState = 1; if s.onopen then s:onopen(); end - websocket_listeners.onincoming(handler, b); + websocket_listeners.onincoming(conn, b); end); return s; diff --git a/net/websocket/frames.lua b/net/websocket/frames.lua index 5fe96d45..ba25d261 100644 --- a/net/websocket/frames.lua +++ b/net/websocket/frames.lua @@ -21,8 +21,8 @@ local t_concat = table.concat; local s_byte = string.byte; local s_char= string.char; local s_sub = string.sub; -local s_pack = string.pack; -local s_unpack = string.unpack; +local s_pack = string.pack; -- luacheck: ignore 143 +local s_unpack = string.unpack; -- luacheck: ignore 143 if not s_pack and softreq"struct" then s_pack = softreq"struct".pack; @@ -112,9 +112,9 @@ end -- TODO: optimize local function apply_mask(str, key, from, to) from = from or 1 - if from < 0 then from = #str + from + 1 end -- negative indicies + if from < 0 then from = #str + from + 1 end -- negative indices to = to or #str - if to < 0 then to = #str + to + 1 end -- negative indicies + if to < 0 then to = #str + to + 1 end -- negative indices local key_len = #key local counter = 0; local data = {}; diff --git a/plugins/adhoc/adhoc.lib.lua b/plugins/adhoc/adhoc.lib.lua index 87415636..0b910299 100644 --- a/plugins/adhoc/adhoc.lib.lua +++ b/plugins/adhoc/adhoc.lib.lua @@ -36,30 +36,30 @@ function _M.handle_cmd(command, origin, stanza) local data, state = command:handler(dataIn, states[sessionid]); states[sessionid] = state; - local cmdtag; + local cmdreply; if data.status == "completed" then states[sessionid] = nil; - cmdtag = command:cmdtag("completed", sessionid); + cmdreply = command:cmdtag("completed", sessionid); elseif data.status == "canceled" then states[sessionid] = nil; - cmdtag = command:cmdtag("canceled", sessionid); + cmdreply = command:cmdtag("canceled", sessionid); elseif data.status == "error" then states[sessionid] = nil; local reply = st.error_reply(stanza, data.error.type, data.error.condition, data.error.message); origin.send(reply); return true; else - cmdtag = command:cmdtag("executing", sessionid); + cmdreply = command:cmdtag("executing", sessionid); data.actions = data.actions or { "complete" }; end for name, content in pairs(data) do if name == "info" then - cmdtag:tag("note", {type="info"}):text(content):up(); + cmdreply:tag("note", {type="info"}):text(content):up(); elseif name == "warn" then - cmdtag:tag("note", {type="warn"}):text(content):up(); + cmdreply:tag("note", {type="warn"}):text(content):up(); elseif name == "error" then - cmdtag:tag("note", {type="error"}):text(content.message):up(); + cmdreply:tag("note", {type="error"}):text(content.message):up(); elseif name == "actions" then local actions = st.stanza("actions", { execute = content.default }); for _, action in ipairs(content) do @@ -70,17 +70,17 @@ function _M.handle_cmd(command, origin, stanza) command.name, command.node, action); end end - cmdtag:add_child(actions); + cmdreply:add_child(actions); elseif name == "form" then - cmdtag:add_child((content.layout or content):form(content.values)); + cmdreply:add_child((content.layout or content):form(content.values)); elseif name == "result" then - cmdtag:add_child((content.layout or content):form(content.values, "result")); + cmdreply:add_child((content.layout or content):form(content.values, "result")); elseif name == "other" then - cmdtag:add_child(content); + cmdreply:add_child(content); end end local reply = st.reply(stanza); - reply:add_child(cmdtag); + reply:add_child(cmdreply); origin.send(reply); return true; diff --git a/plugins/adhoc/mod_adhoc.lua b/plugins/adhoc/mod_adhoc.lua index 1c956021..8ffdc7de 100644 --- a/plugins/adhoc/mod_adhoc.lua +++ b/plugins/adhoc/mod_adhoc.lua @@ -45,8 +45,8 @@ module:hook("host-disco-info-node", function (event) end); module:hook("host-disco-items-node", function (event) - local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; - if node ~= xmlns_cmd then + local stanza, reply, disco_node = event.stanza, event.reply, event.node; + if disco_node ~= xmlns_cmd then return; end diff --git a/plugins/mod_admin_adhoc.lua b/plugins/mod_admin_adhoc.lua index f3de6793..501411a2 100644 --- a/plugins/mod_admin_adhoc.lua +++ b/plugins/mod_admin_adhoc.lua @@ -3,6 +3,7 @@ -- This file is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +-- luacheck: ignore 212/self 212/data 212/state 412/err local _G = _G; @@ -95,7 +96,12 @@ local change_user_password_command_handler = adhoc_simple(change_user_password_l end local username, host, resource = jid.split(fields.accountjid); if module_host ~= host then - return { status = "completed", error = { message = "Trying to change the password of a user on " .. host .. " but command was sent to " .. module_host}}; + return { + status = "completed", + error = { + message = "Trying to change the password of a user on " .. host .. " but command was sent to " .. module_host + } + }; end if usermanager_user_exists(username, host) and usermanager_set_password(username, fields.password, host, nil) then return { status = "completed", info = "Password successfully changed" }; @@ -207,8 +213,8 @@ local get_user_password_handler = adhoc_simple(get_user_password_layout, functio return generate_error_message(err); end local user, host, resource = jid.split(fields.accountjid); - local accountjid = ""; - local password = ""; + local accountjid; + local password; if host ~= module_host then return { status = "completed", error = { message = "Tried to get password for a user on " .. host .. " but command was sent to " .. module_host } }; elseif usermanager_user_exists(user, host) then @@ -246,15 +252,15 @@ local get_user_roster_handler = adhoc_simple(get_user_roster_layout, function(fi local roster = rm_load_roster(user, host); local query = st.stanza("query", { xmlns = "jabber:iq:roster" }); - for jid in pairs(roster) do - if jid then + for contact_jid in pairs(roster) do + if contact_jid then query:tag("item", { - jid = jid, - subscription = roster[jid].subscription, - ask = roster[jid].ask, - name = roster[jid].name, + jid = contact_jid, + subscription = roster[contact_jid].subscription, + ask = roster[contact_jid].ask, + name = roster[contact_jid].name, }); - for group in pairs(roster[jid].groups) do + for group in pairs(roster[contact_jid].groups) do query:tag("group"):text(group):up(); end query:up(); @@ -299,8 +305,8 @@ local get_user_stats_handler = adhoc_simple(get_user_stats_layout, function(fiel local rostersize = 0; local IPs = ""; local resources = ""; - for jid in pairs(roster) do - if jid then + for contact_jid in pairs(roster) do + if contact_jid then rostersize = rostersize + 1; end end @@ -369,7 +375,7 @@ local list_s2s_this_result = dataforms_new { { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/s2s#list" }; { name = "sessions", type = "text-multi", label = "Connections:" }; - { name = "num_in", type = "text-single", label = "#incomming connections:" }; + { name = "num_in", type = "text-single", label = "#incoming connections:" }; { name = "num_out", type = "text-single", label = "#outgoing connections:" }; }; diff --git a/plugins/mod_admin_telnet.lua b/plugins/mod_admin_telnet.lua index 5e8d8534..5bd32987 100644 --- a/plugins/mod_admin_telnet.lua +++ b/plugins/mod_admin_telnet.lua @@ -336,6 +336,43 @@ function def_env.server:memory() return true, "OK"; end +def_env.timer = {}; + +function def_env.timer:info() + local socket = require "socket"; + local print = self.session.print; + local add_task = require"util.timer".add_task; + local h, params = add_task.h, add_task.params; + if h then + print("-- util.timer"); + for i, id in ipairs(h.ids) do + if not params[id] then + print(os.date("%F %T", h.priorities[i]), h.items[id]); + elseif not params[id].callback then + print(os.date("%F %T", h.priorities[i]), h.items[id], unpack(params[id])); + else + print(os.date("%F %T", h.priorities[i]), params[id].callback, unpack(params[id])); + end + end + end + if server.event_base then + local count = 0; + for k, v in pairs(debug.getregistry()) do + if type(v) == "function" and v.callback and v.callback == add_task._on_timer then + count = count + 1; + end + end + print(count .. " libevent callbacks"); + end + if h then + local next_time = h:peek(); + if next_time then + return true, os.date("Next event at %F %T (in %%.6fs)", next_time):format(next_time - socket.gettime()); + end + end + return true; +end + def_env.module = {}; local function get_hosts_set(hosts, module) @@ -955,14 +992,15 @@ local function check_muc(jid) return room_name, host; end -function def_env.muc:create(room_jid) +function def_env.muc:create(room_jid, config) local room_name, host = check_muc(room_jid); if not room_name then return room_name, host; end if not room_name then return nil, host end - if hosts[host].modules.muc.rooms[room_jid] then return nil, "Room exists already" end - return hosts[host].modules.muc.create_room(room_jid); + if config ~= nil and type(config) ~= "table" then return nil, "Config must be a table"; end + if hosts[host].modules.muc.get_room_from_jid(room_jid) then return nil, "Room exists already" end + return hosts[host].modules.muc.create_room(room_jid, config); end function def_env.muc:room(room_jid) @@ -970,7 +1008,7 @@ function def_env.muc:room(room_jid) if not room_name then return room_name, host; end - local room_obj = hosts[host].modules.muc.rooms[room_jid]; + local room_obj = hosts[host].modules.muc.get_room_from_jid(room_jid); if not room_obj then return nil, "No such room: "..room_jid; end @@ -984,8 +1022,8 @@ function def_env.muc:list(host) end local print = self.session.print; local c = 0; - for name in keys(host_session.modules.muc.rooms) do - print(name); + for room in host_session.modules.muc.each_room() do + print(room.jid); c = c + 1; end return true, c.." rooms"; @@ -1175,7 +1213,7 @@ function printbanner(session) if option == "short" or option == "full" then session.print("Welcome to the Prosody administration console. For a list of commands, type: help"); session.print("You may find more help on using this console in our online documentation at "); - session.print("http://prosody.im/doc/console\n"); + session.print("https://prosody.im/doc/console\n"); end if option ~= "short" and option ~= "full" and option ~= "graphic" then session.print(option); diff --git a/plugins/mod_announce.lua b/plugins/mod_announce.lua index 9327556c..ee3bb5b7 100644 --- a/plugins/mod_announce.lua +++ b/plugins/mod_announce.lua @@ -91,8 +91,6 @@ function announce_handler(self, data, state) else return { status = "executing", actions = {"next", "complete", default = "complete"}, form = announce_layout }, "executing"; end - - return true; end local adhoc_new = module:require "adhoc".new; diff --git a/plugins/mod_bosh.lua b/plugins/mod_bosh.lua index 8cda4a23..1908e5ed 100644 --- a/plugins/mod_bosh.lua +++ b/plugins/mod_bosh.lua @@ -6,9 +6,6 @@ -- COPYING file in the source package for more information. -- -module:set_global(); -- Global module - -local hosts = _G.hosts; local new_xmpp_stream = require "util.xmppstream".new; local sm = require "core.sessionmanager"; local sm_destroy_session = sm.destroy_session; @@ -16,12 +13,14 @@ local new_uuid = require "util.uuid".generate; local core_process_stanza = prosody.core_process_stanza; local st = require "util.stanza"; local logger = require "util.logger"; -local log = logger.init("mod_bosh"); +local log = module._log; local initialize_filters = require "util.filters".initialize; local math_min = math.min; -local xpcall, tostring, type = xpcall, tostring, type; +local tostring, type = tostring, type; local traceback = debug.traceback; +local runner = require"util.async".runner; local nameprep = require "util.encodings".stringprep.nameprep; +local cache = require "util.cache"; local xmlns_streams = "http://etherx.jabber.org/streams"; local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; @@ -48,33 +47,14 @@ local cross_domain = module:get_option("cross_domain_bosh", false); if cross_domain == true then cross_domain = "*"; end if type(cross_domain) == "table" then cross_domain = table.concat(cross_domain, ", "); end -local trusted_proxies = module:get_option_set("trusted_proxies", { "127.0.0.1", "::1" })._items; - -local function get_ip_from_request(request) - local ip = request.conn:ip(); - local forwarded_for = request.headers.x_forwarded_for; - if forwarded_for then - forwarded_for = forwarded_for..", "..ip; - for forwarded_ip in forwarded_for:gmatch("[^%s,]+") do - if not trusted_proxies[forwarded_ip] then - ip = forwarded_ip; - end - end - end - return ip; -end - local t_insert, t_remove, t_concat = table.insert, table.remove, table.concat; -local os_time = os.time; -- All sessions, and sessions that have no requests open -local sessions, inactive_sessions = module:shared("sessions", "inactive_sessions"); +local sessions = module:shared("sessions"); -- Used to respond to idle sessions (those with waiting requests) -local waiting_requests = module:shared("waiting_requests"); function on_destroy_request(request) log("debug", "Request destroyed: %s", tostring(request)); - waiting_requests[request] = nil; local session = sessions[request.context.sid]; if session then local requests = session.requests; @@ -88,9 +68,24 @@ function on_destroy_request(request) -- If this session now has no requests open, mark it as inactive local max_inactive = session.bosh_max_inactive; if max_inactive and #requests == 0 then - inactive_sessions[session] = os_time() + max_inactive; + if session.inactive_timer then + session.inactive_timer:stop(); + end + session.inactive_timer = module:add_timer(max_inactive, check_inactive, session, request.context, + "BOSH client silent for over "..max_inactive.." seconds"); (session.log or log)("debug", "BOSH session marked as inactive (for %ds)", max_inactive); end + if session.bosh_wait_timer then + session.bosh_wait_timer:stop(); + session.bosh_wait_timer = nil; + end + end +end + +function check_inactive(now, session, context, reason) -- luacheck: ignore 212/now + if not session.destroyed then + sessions[context.sid] = nil; + sm_destroy_session(session, reason); end end @@ -124,7 +119,7 @@ function handle_POST(event) local headers = response.headers; headers.content_type = "text/xml; charset=utf-8"; - if cross_domain and event.request.headers.origin then + if cross_domain and request.headers.origin then set_cross_domain_headers(response); end @@ -148,8 +143,14 @@ function handle_POST(event) if session then -- Session was marked as inactive, since we have -- a request open now, unmark it - if inactive_sessions[session] and #session.requests > 0 then - inactive_sessions[session] = nil; + if session.inactive_timer and #session.requests > 0 then + session.inactive_timer:stop(); + session.inactive_timer = nil; + end + + if session.bosh_wait_timer then + session.bosh_wait_timer:stop(); + session.bosh_wait_timer = nil; end local r = session.requests; @@ -177,9 +178,6 @@ function handle_POST(event) if not response.finished then -- We're keeping this request open, to respond later log("debug", "Have nothing to say, so leaving request unanswered for now"); - if session.bosh_wait then - waiting_requests[response] = os_time() + session.bosh_wait; - end end if session.bosh_terminate then @@ -187,10 +185,22 @@ function handle_POST(event) session:close(); return nil; else + if session.bosh_wait and #session.requests > 0 then + session.bosh_wait_timer = module:add_timer(session.bosh_wait, after_bosh_wait, session.requests[1], session) + end + return true; -- Inform http server we shall reply later end - elseif response.finished then - return; -- A response has been sent already + elseif response.finished or context.ignore_request then + if response.finished then + module:log("debug", "Response finished"); + end + if context.ignore_request then + module:log("debug", "Ignoring this request"); + end + -- A response has been sent already, or we're ignoring this request + -- (e.g. so a different instance of the module can handle it) + return; end module:log("warn", "Unable to associate request with a session (incomplete request?)"); local close_reply = st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", @@ -198,13 +208,17 @@ function handle_POST(event) return tostring(close_reply) .. "\n"; end +function after_bosh_wait(now, request, session) -- luacheck: ignore 212 + if request.conn then + session.send(""); + end +end local function bosh_reset_stream(session) session.notopen = true; end local stream_xmlns_attr = { xmlns = "urn:ietf:params:xml:ns:xmpp-streams" }; - local function bosh_close_stream(session, reason) - (session.log or log)("info", "BOSH client disconnected"); + (session.log or log)("info", "BOSH client disconnected: %s", tostring((reason and reason.condition or reason) or "session close")); local close_reply = st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", ["xmlns:stream"] = xmlns_streams }); @@ -237,21 +251,22 @@ local function bosh_close_stream(session, reason) held_request:send(response_body); end sessions[session.sid] = nil; - inactive_sessions[session] = nil; sm_destroy_session(session); end +local runner_callbacks = { }; + -- Handle the <body> tag in the request payload. function stream_callbacks.streamopened(context, attr) local request, response = context.request, context.response; - local sid = attr.sid; + local sid, rid = attr.sid, tonumber(attr.rid); log("debug", "BOSH body open (sid: %s)", sid or "<none>"); + context.rid = rid; if not sid then -- New session request context.notopen = nil; -- Signals that we accept this opening tag local to_host = nameprep(attr.to); - local rid = tonumber(attr.rid); local wait = tonumber(attr.wait); if not to_host then log("debug", "BOSH client tried to connect to invalid host: %s", tostring(attr.to)); @@ -259,12 +274,10 @@ function stream_callbacks.streamopened(context, attr) ["xmlns:stream"] = xmlns_streams, condition = "improper-addressing" }); response:send(tostring(close_reply)); return; - elseif not hosts[to_host] then - -- Unknown host - log("debug", "BOSH client tried to connect to unknown host: %s", tostring(attr.to)); - local close_reply = st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", - ["xmlns:stream"] = xmlns_streams, condition = "host-unknown" }); - response:send(tostring(close_reply)); + elseif to_host ~= module.host then + -- Could be meant for a different instance of the module + -- if multiple instances are loaded with the same URL then this can happen + context.ignore_request = true; return; end if not rid or (not wait and attr.wait or wait < 0 or wait % 1 ~= 0) then @@ -275,28 +288,32 @@ function stream_callbacks.streamopened(context, attr) return; end - rid = rid - 1; wait = math_min(wait, bosh_max_wait); -- New session sid = new_uuid(); local session = { - type = "c2s_unauthed", conn = request.conn, sid = sid, rid = rid, host = to_host, + type = "c2s_unauthed", conn = request.conn, sid = sid, host = attr.to, + rid = rid - 1, -- Hack for initial session setup, "previous" rid was $current_request - 1 bosh_version = attr.ver, bosh_wait = wait, streamid = sid, - bosh_max_inactive = bosh_max_inactivity, + bosh_max_inactive = bosh_max_inactivity, bosh_responses = cache.new(BOSH_HOLD+1):table(); requests = { }, send_buffer = {}, reset_stream = bosh_reset_stream, close = bosh_close_stream, dispatch_stanza = core_process_stanza, notopen = true, log = logger.init("bosh"..sid), secure = consider_bosh_secure or request.secure, - ip = get_ip_from_request(request); + ip = request.ip; }; sessions[sid] = session; + session.thread = runner(function (stanza) + session:dispatch_stanza(stanza); + end, runner_callbacks, session); + local filter = initialize_filters(session); session.log("debug", "BOSH session created for request from %s", session.ip); log("info", "New BOSH session, assigned it sid '%s'", sid); - hosts[session.host].events.fire_event("bosh-session", { session = session, request = request }); + module:fire_event("bosh-session", { session = session, request = request }); -- Send creation response local creating_session = true; @@ -335,8 +352,9 @@ function stream_callbacks.streamopened(context, attr) body_attr["xmlns:xmpp"] = "urn:xmpp:xbosh"; body_attr["xmpp:version"] = "1.0"; end - session.bosh_last_response = st.stanza("body", body_attr):top_tag()..t_concat(session.send_buffer).."</body>"; - oldest_request:send(session.bosh_last_response); + local response_xml = st.stanza("body", body_attr):top_tag()..t_concat(session.send_buffer).."</body>"; + session.bosh_responses[oldest_request.context.rid] = response_xml; + oldest_request:send(response_xml); session.send_buffer = {}; end return true; @@ -356,24 +374,31 @@ function stream_callbacks.streamopened(context, attr) session.conn = request.conn; if session.rid then - local rid = tonumber(attr.rid); local diff = rid - session.rid; -- Diff should be 1 for a healthy request + session.log("debug", "rid: %d, sess: %s, diff: %d", rid, session.rid, diff) if diff ~= 1 then context.sid = sid; context.notopen = nil; - if diff == 2 then + if diff == 2 then -- Missed a request -- Hold request, but don't process it (ouch!) session.log("debug", "rid skipped: %d, deferring this request", rid-1) context.defer = true; session.bosh_deferred = { context = context, sid = sid, rid = rid, terminate = attr.type == "terminate" }; return; end + -- Set a marker to indicate that stanzas in this request should NOT be processed + -- (these stanzas will already be in the XML parser's buffer) context.ignore = true; - if diff == 0 then - -- Re-send previous response, ignore stanzas in this request - session.log("debug", "rid repeated, ignoring: %s (diff %d)", session.rid, diff); - response:send(session.bosh_last_response); + if session.bosh_responses[rid] then + -- Re-send past response, ignore stanzas in this request + session.log("debug", "rid repeated within window, replaying old response"); + response:send(session.bosh_responses[rid]); + return; + elseif diff == 0 then + session.log("debug", "current rid repeated, ignoring stanzas"); + t_insert(session.requests, response); + context.sid = sid; return; end -- Session broken, destroy it @@ -397,13 +422,18 @@ function stream_callbacks.streamopened(context, attr) if session.notopen then local features = st.stanza("stream:features"); - hosts[session.host].events.fire_event("stream-features", { origin = session, features = features }); + module:fire_event("stream-features", { origin = session, features = features }); session.send(features); session.notopen = nil; end end local function handleerr(err) log("error", "Traceback[bosh]: %s", traceback(tostring(err), 2)); end + +function runner_callbacks:error(err) -- luacheck: ignore 212/self + return handleerr(err); +end + function stream_callbacks.handlestanza(context, stanza) if context.ignore then return; end log("debug", "BOSH stanza received: %s\n", stanza:top_tag()); @@ -417,9 +447,7 @@ function stream_callbacks.handlestanza(context, stanza) t_insert(session.bosh_deferred, stanza); else stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end + session.thread:run(stanza); end else log("debug", "No session for this stanza! (sid: %s)", context.sid or "none!"); @@ -432,13 +460,13 @@ function stream_callbacks.streamclosed(context) if not context.defer and session.bosh_deferred then -- Handle deferred stanzas now local deferred_stanzas = session.bosh_deferred; - local context = deferred_stanzas.context; + local deferred_context = deferred_stanzas.context; session.bosh_deferred = nil; log("debug", "Handling deferred stanzas from rid %d", deferred_stanzas.rid); session.rid = deferred_stanzas.rid; - t_insert(session.requests, context.response); + t_insert(session.requests, deferred_context.response); for _, stanza in ipairs(deferred_stanzas) do - stream_callbacks.handlestanza(context, stanza); + stream_callbacks.handlestanza(deferred_context, stanza); end if deferred_stanzas.terminate then session.bosh_terminate = true; @@ -452,8 +480,8 @@ function stream_callbacks.streamclosed(context) end function stream_callbacks.error(context, error) - log("debug", "Error parsing BOSH request payload; %s", error); if not context.sid then + log("debug", "Error parsing BOSH request payload; %s", error); local response = context.response; local close_reply = st.stanza("body", { xmlns = xmlns_bosh, type = "terminate", ["xmlns:stream"] = xmlns_streams, condition = "bad-request" }); @@ -462,6 +490,7 @@ function stream_callbacks.error(context, error) end local session = sessions[context.sid]; + (session and session.log or log)("warn", "Error parsing BOSH request payload; %s", error); if error == "stream-error" then -- Remote stream error, we close normally session:close(); else @@ -469,65 +498,25 @@ function stream_callbacks.error(context, error) end end -local dead_sessions = module:shared("dead_sessions"); -function on_timer() - -- log("debug", "Checking for requests soon to timeout..."); - -- Identify requests timing out within the next few seconds - local now = os_time() + 3; - for request, reply_before in pairs(waiting_requests) do - if reply_before <= now then - log("debug", "%s was soon to timeout (at %d, now %d), sending empty response", tostring(request), reply_before, now); - -- Send empty response to let the - -- client know we're still here - if request.conn then - sessions[request.context.sid].send(""); - end - end - end - - now = now - 3; - local n_dead_sessions = 0; - for session, close_after in pairs(inactive_sessions) do - if close_after < now then - (session.log or log)("debug", "BOSH client inactive too long, destroying session at %d", now); - sessions[session.sid] = nil; - inactive_sessions[session] = nil; - n_dead_sessions = n_dead_sessions + 1; - dead_sessions[n_dead_sessions] = session; - end - end - - for i=1,n_dead_sessions do - local session = dead_sessions[i]; - dead_sessions[i] = nil; - sm_destroy_session(session, "BOSH client silent for over "..session.bosh_max_inactive.." seconds"); - end - return 1; -end -module:add_timer(1, on_timer); - - local GET_response = { headers = { content_type = "text/html"; }; body = [[<html><body> <p>It works! Now point your BOSH client to this URL to connect to Prosody.</p> - <p>For more information see <a href="http://prosody.im/doc/setting_up_bosh">Prosody: Setting up BOSH</a>.</p> + <p>For more information see <a href="https://prosody.im/doc/setting_up_bosh">Prosody: Setting up BOSH</a>.</p> </body></html>]]; }; -function module.add_host(module) - module:depends("http"); - module:provides("http", { - default_path = "/http-bind"; - route = { - ["GET"] = GET_response; - ["GET /"] = GET_response; - ["OPTIONS"] = handle_OPTIONS; - ["OPTIONS /"] = handle_OPTIONS; - ["POST"] = handle_POST; - ["POST /"] = handle_POST; - }; - }); -end +module:depends("http"); +module:provides("http", { + default_path = "/http-bind"; + route = { + ["GET"] = GET_response; + ["GET /"] = GET_response; + ["OPTIONS"] = handle_OPTIONS; + ["OPTIONS /"] = handle_OPTIONS; + ["POST"] = handle_POST; + ["POST /"] = handle_POST; + }; +}); diff --git a/plugins/mod_c2s.lua b/plugins/mod_c2s.lua index 7f0d1b01..3816a262 100644 --- a/plugins/mod_c2s.lua +++ b/plugins/mod_c2s.lua @@ -15,9 +15,9 @@ local sessionmanager = require "core.sessionmanager"; local st = require "util.stanza"; local sm_new_session, sm_destroy_session = sessionmanager.new_session, sessionmanager.destroy_session; local uuid_generate = require "util.uuid".generate; +local runner = require "util.async".runner; -local xpcall, tostring, type = xpcall, tostring, type; -local traceback = debug.traceback; +local tostring, type = tostring, type; local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; @@ -28,6 +28,7 @@ local stream_close_timeout = module:get_option_number("c2s_close_timeout", 5); local opt_keepalives = module:get_option_boolean("c2s_tcp_keepalives", module:get_option_boolean("tcp_keepalives", true)); local measure_connections = module:measure("connections", "amount"); +local measure_ipv6 = module:measure("ipv6", "amount"); local sessions = module:shared("sessions"); local core_process_stanza = prosody.core_process_stanza; @@ -35,13 +36,19 @@ local hosts = prosody.hosts; local stream_callbacks = { default_ns = "jabber:client" }; local listener = {}; +local runner_callbacks = {}; module:hook("stats-update", function () local count = 0; - for _ in pairs(sessions) do + local ipv6 = 0; + for _, session in pairs(sessions) do count = count + 1; + if session.ip and session.ip:match(":") then + ipv6 = ipv6 + 1; + end end measure_connections(count); + measure_ipv6(ipv6); end); --- Stream events handlers @@ -134,12 +141,9 @@ function stream_callbacks.error(session, error, data) end end -local function handleerr(err) log("error", "Traceback[c2s]: %s", traceback(tostring(err), 2)); end function stream_callbacks.handlestanza(session, stanza) stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end + session.thread:run(stanza); end --- Session methods @@ -220,6 +224,18 @@ module:hook_global("user-password-changed", function(event) end end, 200); +function runner_callbacks:ready() + self.data.conn:resume(); +end + +function runner_callbacks:waiting() + self.data.conn:pause(); +end + +function runner_callbacks:error(err) + (self.data.log or log)("error", "Traceback[c2s]: %s", err); +end + --- Port listener function listener.onconnect(conn) local session = sm_new_session(conn); @@ -256,6 +272,10 @@ function listener.onconnect(conn) session.stream:reset(); end + session.thread = runner(function (stanza) + core_process_stanza(session, stanza); + end, runner_callbacks, session); + local filter = session.filter; function session.data(data) -- Parse the data, which will store stanzas in session.pending_stanzas diff --git a/plugins/mod_component.lua b/plugins/mod_component.lua index 4a210495..743a16a3 100644 --- a/plugins/mod_component.lua +++ b/plugins/mod_component.lua @@ -38,7 +38,7 @@ end function module.add_host(module) if module:get_host_type() ~= "component" then - error("Don't load mod_component manually, it should be for a component, please see http://prosody.im/doc/components", 0); + error("Don't load mod_component manually, it should be for a component, please see https://prosody.im/doc/components", 0); end local env = module.environment; diff --git a/plugins/mod_http.lua b/plugins/mod_http.lua index a15e8cda..3f7190d2 100644 --- a/plugins/mod_http.lua +++ b/plugins/mod_http.lua @@ -50,6 +50,9 @@ end local function redir_handler(event) event.response.headers.location = event.request.path.."/"; + if event.request.url.query then + event.response.headers.location = event.response.headers.location .. "?" .. event.request.url.query + end return 301; end @@ -120,7 +123,7 @@ function module.add_host(module) module:log("warn", "App %s added handler twice for '%s', ignoring", app_name, event_name); end else - module:log("error", "Invalid route in %s, %q. See http://prosody.im/doc/developers/http#routes", app_name, key); + module:log("error", "Invalid route in %s, %q. See https://prosody.im/doc/developers/http#routes", app_name, key); end end local services = portmanager.get_active_services(); @@ -147,6 +150,31 @@ function module.add_host(module) end end +local trusted_proxies = module:get_option_set("trusted_proxies", { "127.0.0.1", "::1" })._items; + +local function get_ip_from_request(request) + local ip = request.conn:ip(); + local forwarded_for = request.headers.x_forwarded_for; + if forwarded_for then + forwarded_for = forwarded_for..", "..ip; + for forwarded_ip in forwarded_for:gmatch("[^%s,]+") do + if not trusted_proxies[forwarded_ip] then + ip = forwarded_ip; + end + end + end + return ip; +end + +module:wrap_object_event(server._events, false, function (handlers, event_name, event_data) + local request = event_data.request; + if request then + -- Not included in eg http-error events + request.ip = get_ip_from_request(request); + end + return handlers(event_name, event_data); +end); + module:provides("net", { name = "http"; listener = server.listener; diff --git a/plugins/mod_iq.lua b/plugins/mod_iq.lua index c6d62e85..87c3a467 100644 --- a/plugins/mod_iq.lua +++ b/plugins/mod_iq.lua @@ -13,7 +13,7 @@ local full_sessions = prosody.full_sessions; if module:get_host_type() == "local" then module:hook("iq/full", function(data) - -- IQ to full JID recieved + -- IQ to full JID received local origin, stanza = data.origin, data.stanza; local session = full_sessions[stanza.attr.to]; @@ -27,7 +27,7 @@ if module:get_host_type() == "local" then end module:hook("iq/bare", function(data) - -- IQ to bare JID recieved + -- IQ to bare JID received local stanza = data.stanza; local type = stanza.attr.type; @@ -44,7 +44,7 @@ module:hook("iq/bare", function(data) end); module:hook("iq/self", function(data) - -- IQ to self JID recieved + -- IQ to self JID received local stanza = data.stanza; local type = stanza.attr.type; @@ -60,7 +60,7 @@ module:hook("iq/self", function(data) end); module:hook("iq/host", function(data) - -- IQ to a local host recieved + -- IQ to a local host received local stanza = data.stanza; local type = stanza.attr.type; diff --git a/plugins/mod_legacyauth.lua b/plugins/mod_legacyauth.lua index 5edc26bb..0f41d3e7 100644 --- a/plugins/mod_legacyauth.lua +++ b/plugins/mod_legacyauth.lua @@ -35,7 +35,8 @@ module:hook("stanza/iq/jabber:iq:auth:query", function(event) local session, stanza = event.origin, event.stanza; if session.type ~= "c2s_unauthed" then - (session.sends2s or session.send)(st.error_reply(stanza, "cancel", "service-unavailable", "Legacy authentication is only allowed for unauthenticated client connections.")); + (session.sends2s or session.send)(st.error_reply(stanza, "cancel", "service-unavailable", + "Legacy authentication is only allowed for unauthenticated client connections.")); return true; end diff --git a/plugins/mod_limits.lua b/plugins/mod_limits.lua index 3fc3fcaa..914d5c44 100644 --- a/plugins/mod_limits.lua +++ b/plugins/mod_limits.lua @@ -51,18 +51,18 @@ end local default_filter_set = {}; function default_filter_set.bytes_in(bytes, session) - local throttle = session.throttle; - if throttle then - local ok, balance, outstanding = throttle:poll(#bytes, true); + local sess_throttle = session.throttle; + if sess_throttle then + local ok, balance, outstanding = sess_throttle:poll(#bytes, true); if not ok then - session.log("debug", "Session over rate limit (%d) with %d (by %d), pausing", throttle.max, #bytes, outstanding); + session.log("debug", "Session over rate limit (%d) with %d (by %d), pausing", sess_throttle.max, #bytes, outstanding); outstanding = ceil(outstanding); session.conn:pause(); -- Read no more data from the connection until there is no outstanding data local outstanding_data = bytes:sub(-outstanding); bytes = bytes:sub(1, #bytes-outstanding); timer.add_task(limits_resolution, function () if not session.conn then return; end - if throttle:peek(#outstanding_data) then + if sess_throttle:peek(#outstanding_data) then session.log("debug", "Resuming paused session"); session.conn:resume(); end diff --git a/plugins/mod_message.lua b/plugins/mod_message.lua index 0d370ec1..4b8154e0 100644 --- a/plugins/mod_message.lua +++ b/plugins/mod_message.lua @@ -63,7 +63,7 @@ local function process_to_bare(bare, origin, stanza) end module:hook("message/full", function(data) - -- message to full JID recieved + -- message to full JID received local origin, stanza = data.origin, data.stanza; local session = full_sessions[stanza.attr.to]; @@ -75,7 +75,7 @@ module:hook("message/full", function(data) end, -1); module:hook("message/bare", function(data) - -- message to bare JID recieved + -- message to bare JID received local origin, stanza = data.origin, data.stanza; return process_to_bare(stanza.attr.to or (origin.username..'@'..origin.host), origin, stanza); diff --git a/plugins/mod_muc_mam.lua b/plugins/mod_muc_mam.lua new file mode 100644 index 00000000..cfc383fc --- /dev/null +++ b/plugins/mod_muc_mam.lua @@ -0,0 +1,385 @@ +-- XEP-0313: Message Archive Management for Prosody MUC +-- Copyright (C) 2011-2017 Kim Alvefur +-- +-- This file is MIT/X11 licensed. + +if module:get_host_type() ~= "component" then + module:log("error", "mod_%s should be loaded only on a MUC component, not normal hosts", module.name); + return; +end + +local xmlns_mam = "urn:xmpp:mam:2"; +local xmlns_delay = "urn:xmpp:delay"; +local xmlns_forward = "urn:xmpp:forward:0"; +local xmlns_st_id = "urn:xmpp:sid:0"; +local xmlns_muc_user = "http://jabber.org/protocol/muc#user"; +local muc_form_enable = "muc#roomconfig_enablearchiving" + +local st = require "util.stanza"; +local rsm = require "util.rsm"; +local jid_bare = require "util.jid".bare; +local jid_split = require "util.jid".split; +local jid_prep = require "util.jid".prep; +local dataform = require "util.dataforms".new; + +local mod_muc = module:depends"muc"; +local get_room_from_jid = mod_muc.get_room_from_jid; + +local is_stanza = st.is_stanza; +local tostring = tostring; +local time_now = os.time; +local m_min = math.min; +local timestamp, timestamp_parse = require "util.datetime".datetime, require "util.datetime".parse; +local default_max_items, max_max_items = 20, module:get_option_number("max_archive_query_results", 50); + +local default_history_length = 20; +local max_history_length = module:get_option_number("max_history_messages", math.huge); + +local function get_historylength(room) + return math.min(room._data.history_length or default_history_length, max_history_length); +end + +local log_all_rooms = module:get_option_boolean("muc_log_all_rooms", false); +local log_by_default = module:get_option_boolean("muc_log_by_default", true); + +local archive_store = "muc_log"; +local archive = module:open_store(archive_store, "archive"); + +if archive.name == "null" or not archive.find then + if not archive.find then + module:log("error", "Attempt to open archive storage returned a driver without archive API support"); + module:log("error", "mod_%s does not support archiving", + archive._provided_by or archive.name and "storage_"..archive.name.."(?)" or "<unknown>"); + else + module:log("error", "Attempt to open archive storage returned null driver"); + end + module:log("info", "See https://prosody.im/doc/storage and https://prosody.im/doc/archiving for more information"); + return false; +end + +local function archiving_enabled(room) + if log_all_rooms then + return true; + end + local enabled = room._data.archiving; + if enabled == nil then + return log_by_default; + end + return enabled; +end + +if not log_all_rooms then + module:hook("muc-config-form", function(event) + local room, form = event.room, event.form; + table.insert(form, + { + name = muc_form_enable, + type = "boolean", + label = "Enable archiving?", + value = archiving_enabled(room), + } + ); + end); + + module:hook("muc-config-submitted/"..muc_form_enable, function(event) + event.room._data.archiving = event.value; + event.status_codes[event.value and "170" or "171"] = true; + end); +end + +-- Note: We ignore the 'with' field as this is internally used for stanza types +local query_form = dataform { + { name = "FORM_TYPE"; type = "hidden"; value = xmlns_mam; }; + { name = "with"; type = "jid-single"; }; + { name = "start"; type = "text-single" }; + { name = "end"; type = "text-single"; }; +}; + +-- Serve form +module:hook("iq-get/bare/"..xmlns_mam..":query", function(event) + local origin, stanza = event.origin, event.stanza; + origin.send(st.reply(stanza):add_child(query_form:form())); + return true; +end); + +-- Handle archive queries +module:hook("iq-set/bare/"..xmlns_mam..":query", function(event) + local origin, stanza = event.origin, event.stanza; + local room_jid = stanza.attr.to; + local room_node = jid_split(room_jid); + local orig_from = stanza.attr.from; + local query = stanza.tags[1]; + + local room = get_room_from_jid(room_jid); + if not room then + origin.send(st.error_reply(stanza, "cancel", "item-not-found")) + return true; + end + local from = jid_bare(orig_from); + + -- Banned or not a member of a members-only room? + local from_affiliation = room:get_affiliation(from); + if from_affiliation == "outcast" -- banned + or room:get_members_only() and not from_affiliation then -- members-only, not a member + origin.send(st.error_reply(stanza, "auth", "forbidden")) + return true; + end + + local qid = query.attr.queryid; + + -- Search query parameters + local qstart, qend; + local form = query:get_child("x", "jabber:x:data"); + if form then + local err; + form, err = query_form:data(form); + if err then + origin.send(st.error_reply(stanza, "modify", "bad-request", select(2, next(err)))); + return true; + end + qstart, qend = form["start"], form["end"]; + end + + if qstart or qend then -- Validate timestamps + local vstart, vend = (qstart and timestamp_parse(qstart)), (qend and timestamp_parse(qend)) + if (qstart and not vstart) or (qend and not vend) then + origin.send(st.error_reply(stanza, "modify", "bad-request", "Invalid timestamp")) + return true; + end + qstart, qend = vstart, vend; + end + + module:log("debug", "Archive query id %s from %s until %s)", + tostring(qid), + qstart and timestamp(qstart) or "the dawn of time", + qend and timestamp(qend) or "now"); + + -- RSM stuff + local qset = rsm.get(query); + local qmax = m_min(qset and qset.max or default_max_items, max_max_items); + local reverse = qset and qset.before or false; + + local before, after = qset and qset.before, qset and qset.after; + if type(before) ~= "string" then before = nil; end + + -- Load all the data! + local data, err = archive:find(room_node, { + start = qstart; ["end"] = qend; -- Time range + limit = qmax + 1; + before = before; after = after; + reverse = reverse; + with = "message<groupchat"; + }); + + if not data then + origin.send(st.error_reply(stanza, "cancel", "internal-server-error")); + return true; + end + local total = tonumber(err); + + local msg_reply_attr = { to = stanza.attr.from, from = stanza.attr.to }; + + local results = {}; + + -- Wrap it in stuff and deliver + local first, last; + local count = 0; + local complete = "true"; + for id, item, when in data do + count = count + 1; + if count > qmax then + complete = nil; + break; + end + local fwd_st = st.message(msg_reply_attr) + :tag("result", { xmlns = xmlns_mam, queryid = qid, id = id }) + :tag("forwarded", { xmlns = xmlns_forward }) + :tag("delay", { xmlns = xmlns_delay, stamp = timestamp(when) }):up(); + + -- Strip <x> tag, containing the original senders JID, unless the room makes this public + if room:get_whois() ~= "anyone" then + item:maptags(function (tag) + if tag.name == "x" and tag.attr.xmlns == xmlns_muc_user then + return nil; + end + return tag; + end); + end + if not is_stanza(item) then + item = st.deserialize(item); + end + item.attr.xmlns = "jabber:client"; + fwd_st:add_child(item); + + if not first then first = id; end + last = id; + + if reverse then + results[count] = fwd_st; + else + origin.send(fwd_st); + end + end + + if reverse then + for i = #results, 1, -1 do + origin.send(results[i]); + end + first, last = last, first; + end + + -- That's all folks! + module:log("debug", "Archive query %s completed", tostring(qid)); + + origin.send(st.reply(stanza) + :tag("fin", { xmlns = xmlns_mam, queryid = qid, complete = complete }) + :add_child(rsm.generate { + first = first, last = last, count = total })); + return true; +end); + +module:hook("muc-get-history", function (event) + local room = event.room; + if not archiving_enabled(room) then return end + local room_jid = room.jid; + local maxstanzas = event.maxstanzas; + local maxchars = event.maxchars; + local since = event.since; + local to = event.to; + + if maxstanzas == 0 or maxchars == 0 then + return -- No history requested + end + + if not maxstanzas or maxstanzas > get_historylength(room) then + maxstanzas = get_historylength(room); + end + + if room._history and #room._history >= maxstanzas then + return -- It can deal with this itself + end + + -- Load all the data! + local query = { + limit = maxstanzas; + start = since; + reverse = true; + with = "message<groupchat"; + } + local data, err = archive:find(jid_split(room_jid), query); + + if not data then + module:log("error", "Could not fetch history: %s", tostring(err)); + return + end + + local history, i = {}, 1; + + for id, item, when in data do + item.attr.to = to; + item:tag("delay", { xmlns = "urn:xmpp:delay", from = room_jid, stamp = timestamp(when) }):up(); -- XEP-0203 + item:tag("stanza-id", { xmlns = xmlns_st_id, by = room_jid, id = id }):up(); + if room:get_whois() ~= "anyone" then + item:maptags(function (tag) + if tag.name == "x" and tag.attr.xmlns == xmlns_muc_user then + return nil; + end + return tag; + end); + end + if maxchars then + local chars = #tostring(item); + if maxchars - chars < 0 then + break + end + maxchars = maxchars - chars; + end + history[i], i = item, i+1; + -- module:log("debug", tostring(item)); + end + function event.next_stanza() + i = i - 1; + return history[i]; + end + return true; +end, 1); + +module:hook("muc-broadcast-messages", function (event) + local room, stanza = event.room, event.stanza; + + -- Filter out <stanza-id> that claim to be from us + stanza:maptags(function (tag) + if tag.name == "stanza-id" and tag.attr.xmlns == xmlns_st_id + and jid_prep(tag.attr.by) == room.jid then + return nil; + end + if tag.name == "x" and tag.attr.xmlns == xmlns_muc_user then + return nil; + end + return tag; + end); + +end, 0); + +-- Handle messages +local function save_to_history(self, stanza) + local room_node, room_host = jid_split(self.jid); + + local stored_stanza = stanza; + + if stanza.name == "message" and self:get_whois() == "anyone" then + stored_stanza = st.clone(stanza); + local actor = jid_bare(self._occupants[stanza.attr.from].jid); + local affiliation = self:get_affiliation(actor) or "none"; + local role = self:get_role(actor) or self:get_default_role(affiliation); + stored_stanza:add_direct_child(st.stanza("x", { xmlns = xmlns_muc_user }) + :tag("item", { affiliation = affiliation; role = role; jid = actor })); + end + + -- Policy check + if not archiving_enabled(self) then return end -- Don't log + + -- And stash it + local with = stanza.name + if stanza.attr.type then + with = with .. "<" .. stanza.attr.type + end + + local id = archive:append(room_node, nil, stored_stanza, time_now(), with); + + if id then + stanza:add_direct_child(st.stanza("stanza-id", { xmlns = xmlns_st_id, by = self.jid, id = id })); + end +end + +module:hook("muc-add-history", function (event) + local room, stanza = event.room, event.stanza; + save_to_history(room, stanza); +end); + +if module:get_option_boolean("muc_log_presences", true) then + module:hook("muc-occupant-joined", function (event) + save_to_history(event.room, st.stanza("presence", { from = event.nick }):tag("x", { xmlns = "http://jabber.org/protocol/muc" })); + end); + module:hook("muc-occupant-left", function (event) + save_to_history(event.room, st.stanza("presence", { type = "unavailable", from = event.nick })); + end); +end + +if not archive.delete then + module:log("warn", "Storage driver %s does not support deletion", archive._provided_by); + module:log("warn", "Archived message will persist after a room has been destroyed"); +else + module:hook("muc-room-destroyed", function(event) + local room_node = jid_split(event.room.jid); + archive:delete(room_node); + end); +end + +-- And role/affiliation changes? + +module:add_feature(xmlns_mam); + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var=xmlns_mam}):up(); +end); diff --git a/plugins/mod_muc_unique.lua b/plugins/mod_muc_unique.lua new file mode 100644 index 00000000..13284745 --- /dev/null +++ b/plugins/mod_muc_unique.lua @@ -0,0 +1,12 @@ +-- XEP-0307: Unique Room Names for Multi-User Chat +local st = require "util.stanza"; +local unique_name = require "util.id".medium; +module:add_feature "http://jabber.org/protocol/muc#unique" +module:hook("iq-get/host/http://jabber.org/protocol/muc#unique:unique", function(event) + local origin, stanza = event.origin, event.stanza; + origin.send(st.reply(stanza) + :tag("unique", {xmlns = "http://jabber.org/protocol/muc#unique"}) + :text(unique_name():lower()) + ); + return true; +end,-1); diff --git a/plugins/mod_pep.lua b/plugins/mod_pep.lua index 1025be37..8ee97757 100644 --- a/plugins/mod_pep.lua +++ b/plugins/mod_pep.lua @@ -18,6 +18,8 @@ local calculate_hash = require "util.caps".calculate_hash; local core_post_stanza = prosody.core_post_stanza; local bare_sessions = prosody.bare_sessions; +local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; + -- Used as canonical 'empty table' local NULL = {}; -- data[user_bare_jid][node] = item_stanza @@ -36,9 +38,6 @@ module.restore = function(state) hash_map = state.hash_map or {}; end -module:add_identity("pubsub", "pep", module:get_option_string("name", "Prosody")); -module:add_feature("http://jabber.org/protocol/pubsub#publish"); - local function subscription_presence(user_bare, recipient) local recipient_bare = jid_bare(recipient); if (recipient_bare == user_bare) then return true end @@ -118,7 +117,7 @@ local function get_caps_hash_from_presence(stanza, current) end module:hook("presence/bare", function(event) - -- inbound presence to bare JID recieved + -- inbound presence to bare JID received local origin, stanza = event.origin, event.stanza; local user = stanza.attr.to or (origin.username..'@'..origin.host); local t = stanza.attr.type; @@ -284,7 +283,23 @@ end); module:hook("account-disco-info", function(event) local reply = event.reply; reply:tag('identity', {category='pubsub', type='pep'}):up(); - reply:tag('feature', {var='http://jabber.org/protocol/pubsub#publish'}):up(); + reply:tag('feature', {var=xmlns_pubsub}):up(); + local features = { + "access-presence", + "auto-create", + "auto-subscribe", + "filtered-notifications", + "item-ids", + "last-published", + "presence-notifications", + "presence-subscribe", + "publish", + "retract-items", + "retrieve-items", + }; + for _, feature in ipairs(features) do + reply:tag('feature', {var=xmlns_pubsub.."#"..feature}):up(); + end end); module:hook("account-disco-items", function(event) diff --git a/plugins/mod_pep_plus.lua b/plugins/mod_pep_plus.lua new file mode 100644 index 00000000..92a41719 --- /dev/null +++ b/plugins/mod_pep_plus.lua @@ -0,0 +1,476 @@ +local pubsub = require "util.pubsub"; +local jid_bare = require "util.jid".bare; +local jid_split = require "util.jid".split; +local jid_join = require "util.jid".join; +local set_new = require "util.set".new; +local st = require "util.stanza"; +local calculate_hash = require "util.caps".calculate_hash; +local is_contact_subscribed = require "core.rostermanager".is_contact_subscribed; +local cache = require "util.cache"; +local set = require "util.set"; + +local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; +local xmlns_pubsub_event = "http://jabber.org/protocol/pubsub#event"; +local xmlns_pubsub_owner = "http://jabber.org/protocol/pubsub#owner"; + +local lib_pubsub = module:require "pubsub"; + +local empty_set = set_new(); + +local services = {}; +local recipients = {}; +local hash_map = {}; + +local host = module.host; + +local known_nodes_map = module:open_store("pep", "map"); +local known_nodes = module:open_store("pep"); + +function module.save() + return { services = services }; +end + +function module.restore(data) + services = data.services; +end + +function is_item_stanza(item) + return st.is_stanza(item) and item.attr.xmlns == xmlns_pubsub and item.name == "item"; +end + +local function subscription_presence(username, recipient) + local user_bare = jid_join(username, host); + local recipient_bare = jid_bare(recipient); + if (recipient_bare == user_bare) then return true; end + return is_contact_subscribed(username, host, recipient_bare); +end + +local function simple_itemstore(username) + return function (config, node) + if config["persist_items"] then + module:log("debug", "Creating new persistent item store for user %s, node %q", username, node); + known_nodes_map:set(username, node, true); + local archive = module:open_store("pep_"..node, "archive"); + return lib_pubsub.archive_itemstore(archive, config, username, node, false); + else + module:log("debug", "Creating new ephemeral item store for user %s, node %q", username, node); + known_nodes_map:set(username, node, nil); + return cache.new(tonumber(config["max_items"])); + end + end +end + +local function get_broadcaster(username) + local user_bare = jid_join(username, host); + local function simple_broadcast(kind, node, jids, item) + local message = st.message({ from = user_bare, type = "headline" }) + :tag("event", { xmlns = xmlns_pubsub_event }) + :tag(kind, { node = node }); + if item then + item = st.clone(item); + item.attr.xmlns = nil; -- Clear the pubsub namespace + message:add_child(item); + end + for jid in pairs(jids) do + module:log("debug", "Sending notification to %s from %s: %s", jid, user_bare, tostring(item)); + message.attr.to = jid; + module:send(message); + end + end + return simple_broadcast; +end + +function get_pep_service(username) + module:log("debug", "get_pep_service(%q)", username); + local user_bare = jid_join(username, host); + local service = services[username]; + if service then + return service; + end + service = pubsub.new({ + capabilities = { + none = { + create = false; + publish = false; + retract = false; + get_nodes = false; + + subscribe = false; + unsubscribe = false; + get_subscription = false; + get_subscriptions = false; + get_items = false; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + subscriber = { + create = false; + publish = false; + retract = false; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + publisher = { + create = false; + publish = true; + retract = true; + get_nodes = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + subscribe_other = false; + unsubscribe_other = false; + get_subscription_other = false; + get_subscriptions_other = false; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = false; + }; + owner = { + create = true; + publish = true; + retract = true; + delete = true; + get_nodes = true; + configure = true; + + subscribe = true; + unsubscribe = true; + get_subscription = true; + get_subscriptions = true; + get_items = true; + + + subscribe_other = true; + unsubscribe_other = true; + get_subscription_other = true; + get_subscriptions_other = true; + + be_subscribed = true; + be_unsubscribed = true; + + set_affiliation = true; + }; + }; + + node_defaults = { + ["max_items"] = 1; + ["persist_items"] = true; + }; + + autocreate_on_publish = true; + autocreate_on_subscribe = true; + + itemstore = simple_itemstore(username); + broadcaster = get_broadcaster(username); + itemcheck = is_item_stanza; + get_affiliation = function (jid) + if jid_bare(jid) == user_bare then + return "owner"; + elseif subscription_presence(username, jid) then + return "subscriber"; + end + end; + + normalize_jid = jid_bare; + }); + local nodes, err = known_nodes:get(username); + if nodes then + module:log("debug", "Restoring nodes for user %s", username); + for node in pairs(nodes) do + module:log("debug", "Restoring node %q", node); + service:create(node, true); + end + elseif err then + module:log("error", "Could not restore nodes for %s: %s", username, err); + else + module:log("debug", "No known nodes"); + end + services[username] = service; + module:add_item("pep-service", { service = service, jid = user_bare }); + return service; +end + +function handle_pubsub_iq(event) + local origin, stanza = event.origin, event.stanza; + local service_name = origin.username; + if stanza.attr.to ~= nil then + service_name = jid_split(stanza.attr.to); + end + local service = get_pep_service(service_name); + + return lib_pubsub.handle_pubsub_iq(event, service) +end + +module:hook("iq/bare/"..xmlns_pubsub..":pubsub", handle_pubsub_iq); +module:hook("iq/bare/"..xmlns_pubsub_owner..":pubsub", handle_pubsub_iq); + +module:add_identity("pubsub", "pep", module:get_option_string("name", "Prosody")); +module:add_feature("http://jabber.org/protocol/pubsub#publish"); + +local function get_caps_hash_from_presence(stanza, current) + local t = stanza.attr.type; + if not t then + local child = stanza:get_child("c", "http://jabber.org/protocol/caps"); + if child then + local attr = child.attr; + if attr.hash then -- new caps + if attr.hash == 'sha-1' and attr.node and attr.ver then + return attr.ver, attr.node.."#"..attr.ver; + end + else -- legacy caps + if attr.node and attr.ver then + return attr.node.."#"..attr.ver.."#"..(attr.ext or ""), attr.node.."#"..attr.ver; + end + end + end + return; -- no or bad caps + elseif t == "unavailable" or t == "error" then + return; + end + return current; -- no caps, could mean caps optimization, so return current +end + +local function resend_last_item(jid, node, service) + local ok, id, item = service:get_last_item(node, jid); + if not ok then return; end + if not id then return; end + service.config.broadcaster("items", node, { [jid] = true }, item); +end + +local function update_subscriptions(recipient, service_name, nodes) + nodes = nodes or empty_set; + + local service_recipients = recipients[service_name]; + if not service_recipients then + service_recipients = {}; + recipients[service_name] = service_recipients; + end + + local current = service_recipients[recipient]; + if not current or type(current) ~= "table" then + current = empty_set; + end + + if (current == empty_set or current:empty()) and (nodes == empty_set or nodes:empty()) then + return; + end + + local service = get_pep_service(service_name); + for node in current - nodes do + service:remove_subscription(node, recipient, recipient); + end + + for node in nodes - current do + service:add_subscription(node, recipient, recipient); + resend_last_item(recipient, node, service); + end + + if nodes == empty_set or nodes:empty() then + nodes = nil; + end + + service_recipients[recipient] = nodes; +end + +module:hook("presence/bare", function(event) + -- inbound presence to bare JID received + local origin, stanza = event.origin, event.stanza; + local t = stanza.attr.type; + local is_self = not stanza.attr.to; + local username = jid_split(stanza.attr.to); + local user_bare = jid_bare(stanza.attr.to); + if is_self then + username = origin.username; + user_bare = jid_join(username, host); + end + + if not t then -- available presence + if is_self or subscription_presence(username, stanza.attr.from) then + local recipient = stanza.attr.from; + local current = recipients[username] and recipients[username][recipient]; + local hash, query_node = get_caps_hash_from_presence(stanza, current); + if current == hash or (current and current == hash_map[hash]) then return; end + if not hash then + update_subscriptions(recipient, username); + else + recipients[username] = recipients[username] or {}; + if hash_map[hash] then + update_subscriptions(recipient, username, hash_map[hash]); + else + recipients[username][recipient] = hash; + local from_bare = origin.type == "c2s" and origin.username.."@"..origin.host; + if is_self or origin.type ~= "c2s" or (recipients[from_bare] and recipients[from_bare][origin.full_jid]) ~= hash then + -- COMPAT from ~= stanza.attr.to because OneTeam can't deal with missing from attribute + origin.send( + st.stanza("iq", {from=user_bare, to=stanza.attr.from, id="disco", type="get"}) + :tag("query", {xmlns = "http://jabber.org/protocol/disco#info", node = query_node}) + ); + end + end + end + end + elseif t == "unavailable" then + update_subscriptions(stanza.attr.from, username); + elseif not is_self and t == "unsubscribe" then + local from = jid_bare(stanza.attr.from); + local subscriptions = recipients[username]; + if subscriptions then + for subscriber in pairs(subscriptions) do + if jid_bare(subscriber) == from then + update_subscriptions(subscriber, username); + end + end + end + end +end, 10); + +module:hook("iq-result/bare/disco", function(event) + local origin, stanza = event.origin, event.stanza; + local disco = stanza:get_child("query", "http://jabber.org/protocol/disco#info"); + if not disco then + return; + end + + -- Process disco response + local is_self = stanza.attr.to == nil; + local user_bare = jid_bare(stanza.attr.to); + local username = jid_split(stanza.attr.to); + if is_self then + username = origin.username; + user_bare = jid_join(username, host); + end + local contact = stanza.attr.from; + local current = recipients[username] and recipients[username][contact]; + if type(current) ~= "string" then return; end -- check if waiting for recipient's response + local ver = current; + if not string.find(current, "#") then + ver = calculate_hash(disco.tags); -- calculate hash + end + local notify = set_new(); + for _, feature in pairs(disco.tags) do + if feature.name == "feature" and feature.attr.var then + local nfeature = feature.attr.var:match("^(.*)%+notify$"); + if nfeature then notify:add(nfeature); end + end + end + hash_map[ver] = notify; -- update hash map + if is_self then + -- Optimization: Fiddle with other local users + for jid, item in pairs(origin.roster) do -- for all interested contacts + if jid then + local contact_node, contact_host = jid_split(jid); + if contact_host == host and item.subscription == "both" or item.subscription == "from" then + update_subscriptions(user_bare, contact_node, notify); + end + end + end + end + update_subscriptions(contact, username, notify); +end); + +module:hook("account-disco-info-node", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + local service_name = origin.username; + if stanza.attr.to ~= nil then + service_name = jid_split(stanza.attr.to); + end + local service = get_pep_service(service_name); + local node = event.node; + local ok = service:get_items(node, jid_bare(stanza.attr.from) or true); + if not ok then return; end + event.exists = true; + reply:tag('identity', {category='pubsub', type='leaf'}):up(); +end); + +module:hook("account-disco-info", function(event) + local origin, reply = event.origin, event.reply; + + reply:tag('identity', {category='pubsub', type='pep'}):up(); + + local username = jid_split(reply.attr.from) or origin.username; + local service = get_pep_service(username); + + local supported_features = lib_pubsub.get_feature_set(service) + set.new{ + -- Features not covered by the above + "access-presence", + "auto-subscribe", + "filtered-notifications", + "last-published", + "persistent-items", + "presence-notifications", + "presence-subscribe", + }; + + for feature in supported_features do + reply:tag('feature', {var=xmlns_pubsub.."#"..feature}):up(); + end +end); + +module:hook("account-disco-items-node", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + local node = event.node; + local is_self = stanza.attr.to == nil; + local user_bare = jid_bare(stanza.attr.to); + local username = jid_split(stanza.attr.to); + if is_self then + username = origin.username; + user_bare = jid_join(username, host); + end + local service = get_pep_service(username); + local ok, ret = service:get_items(node, jid_bare(stanza.attr.from) or true); + if not ok then return; end + event.exists = true; + for _, id in ipairs(ret) do + reply:tag("item", { jid = user_bare, name = id }):up(); + end +end); + +module:hook("account-disco-items", function(event) + local reply, stanza, origin = event.reply, event.stanza, event.origin; + + local is_self = stanza.attr.to == nil; + local user_bare = jid_bare(stanza.attr.to); + local username = jid_split(stanza.attr.to); + if is_self then + username = origin.username; + user_bare = jid_join(username, host); + end + local service = get_pep_service(username); + + local ok, ret = service:get_nodes(jid_bare(stanza.attr.from)); + if not ok then return; end + + for node, node_obj in pairs(ret) do + reply:tag("item", { jid = user_bare, node = node, name = node_obj.config.name }):up(); + end +end); diff --git a/plugins/mod_ping.lua b/plugins/mod_ping.lua index 1a503409..5559c5ca 100644 --- a/plugins/mod_ping.lua +++ b/plugins/mod_ping.lua @@ -21,7 +21,7 @@ module:hook("iq-get/host/urn:xmpp:ping:ping", ping_handler); local datetime = require "util.datetime".datetime; -function ping_command_handler (self, data, state) +function ping_command_handler (self, data, state) -- luacheck: ignore 212 local now = datetime(); return { info = "Pong\n"..now, status = "completed" }; end diff --git a/plugins/mod_posix.lua b/plugins/mod_posix.lua index fccc7a2b..825d3be0 100644 --- a/plugins/mod_posix.lua +++ b/plugins/mod_posix.lua @@ -61,7 +61,7 @@ if not prosody.start_time then -- server-starting if not suid or suid == 0 or suid == "root" then if pposix.getuid() == 0 and not module:get_option_boolean("run_as_root") then module:log("error", "Danger, Will Robinson! Prosody doesn't need to be run as root, so don't do it!"); - module:log("error", "For more information on running Prosody as root, see http://prosody.im/doc/root"); + module:log("error", "For more information on running Prosody as root, see https://prosody.im/doc/root"); prosody.shutdown("Refusing to run as root"); end end @@ -161,23 +161,25 @@ module:hook("server-stopped", remove_pidfile); -- Set signal handlers if have_signal then - signal.signal("SIGTERM", function () - module:log("warn", "Received SIGTERM"); - prosody.unlock_globals(); - prosody.shutdown("Received SIGTERM"); - prosody.lock_globals(); - end); - - signal.signal("SIGHUP", function () - module:log("info", "Received SIGHUP"); - prosody.reload_config(); - prosody.reopen_logfiles(); - end); - - signal.signal("SIGINT", function () - module:log("info", "Received SIGINT"); - prosody.unlock_globals(); - prosody.shutdown("Received SIGINT"); - prosody.lock_globals(); + module:add_timer(0, function () + signal.signal("SIGTERM", function () + module:log("warn", "Received SIGTERM"); + prosody.unlock_globals(); + prosody.shutdown("Received SIGTERM"); + prosody.lock_globals(); + end); + + signal.signal("SIGHUP", function () + module:log("info", "Received SIGHUP"); + prosody.reload_config(); + prosody.reopen_logfiles(); + end); + + signal.signal("SIGINT", function () + module:log("info", "Received SIGINT"); + prosody.unlock_globals(); + prosody.shutdown("Received SIGINT"); + prosody.lock_globals(); + end); end); end diff --git a/plugins/mod_presence.lua b/plugins/mod_presence.lua index 0c243bc6..5056a3a3 100644 --- a/plugins/mod_presence.lua +++ b/plugins/mod_presence.lua @@ -10,7 +10,6 @@ local log = module._log; local require = require; local pairs = pairs; -local t_concat = table.concat; local s_find = string.find; local tonumber = tonumber; @@ -121,6 +120,8 @@ function handle_normal_presence(origin, stanza) stanza.attr.to = nil; -- reset it end +-- luacheck: ignore 212/recipient_session +-- TODO This argument is used in 3rd party modules function send_presence_of_available_resources(user, host, jid, recipient_session, stanza) local h = hosts[host]; local count = 0; @@ -252,7 +253,7 @@ function handle_inbound_presence_subscriptions_and_probes(origin, stanza, from_b end local outbound_presence_handler = function(data) - -- outbound presence recieved + -- outbound presence received local origin, stanza = data.origin, data.stanza; local to = stanza.attr.to; @@ -280,7 +281,7 @@ module:hook("pre-presence/bare", outbound_presence_handler); module:hook("pre-presence/host", outbound_presence_handler); module:hook("presence/bare", function(data) - -- inbound presence to bare JID recieved + -- inbound presence to bare JID received local origin, stanza = data.origin, data.stanza; local to = stanza.attr.to; @@ -306,7 +307,7 @@ module:hook("presence/bare", function(data) return true; end); module:hook("presence/full", function(data) - -- inbound presence to full JID recieved + -- inbound presence to full JID received local origin, stanza = data.origin, data.stanza; local t = stanza.attr.type; diff --git a/plugins/mod_privacy.lua b/plugins/mod_privacy.lua index b749b7c7..bb9c0253 100644 --- a/plugins/mod_privacy.lua +++ b/plugins/mod_privacy.lua @@ -7,7 +7,7 @@ -- COPYING file in the source package for more information. -- - +-- luacheck: ignore 631 -- COMPAT w/ pre 0.10 module:log("error", "The mod_privacy plugin has been replaced by mod_blocklist. Please update your config. For more information see https://prosody.im/doc/modules/mod_privacy"); module:depends("blocklist"); diff --git a/plugins/mod_private.lua b/plugins/mod_private.lua index c01053d5..9375cf80 100644 --- a/plugins/mod_private.lua +++ b/plugins/mod_private.lua @@ -9,7 +9,7 @@ local st = require "util.stanza" -local private_storage = module:open_store(); +local private_storage = module:open_store("private", "map"); module:add_feature("jabber:iq:private"); @@ -22,28 +22,23 @@ module:hook("iq/self/jabber:iq:private:query", function(event) end local tag = query.tags[1]; local key = tag.name..":"..tag.attr.xmlns; - local data, err = private_storage:get(origin.username); - if err then - origin.send(st.error_reply(stanza, "wait", "internal-server-error", err)); - return true; - end if stanza.attr.type == "get" then - if data and data[key] then - origin.send(st.reply(stanza):query("jabber:iq:private"):add_child(st.deserialize(data[key]))); - return true; + local data, err = private_storage:get(origin.username, key); + if data then + origin.send(st.reply(stanza):query("jabber:iq:private"):add_child(st.deserialize(data))); + elseif err then + origin.send(st.error_reply(stanza, "wait", "internal-server-error", err)); else origin.send(st.reply(stanza):add_child(query)); - return true; end + return true; else -- type == set - if not data then data = {}; end; - if #tag == 0 then - data[key] = nil; - else - data[key] = st.preserialize(tag); + local data; + if #tag ~= 0 then + data = st.preserialize(tag); end -- TODO delete datastore if empty - local ok, err = private_storage:set(origin.username, data); + local ok, err = private_storage:set(origin.username, key, data); if not ok then origin.send(st.error_reply(stanza, "wait", "internal-server-error", err)); return true; diff --git a/plugins/mod_proxy65.lua b/plugins/mod_proxy65.lua index cbbfad12..5d05a2d9 100644 --- a/plugins/mod_proxy65.lua +++ b/plugins/mod_proxy65.lua @@ -44,7 +44,7 @@ function listener.onincoming(conn, data) end -- else error, unexpected input conn:write("\5\255"); -- send (SOCKS version 5, no acceptable method) conn:close(); - module:log("debug", "Invalid SOCKS5 greeting recieved: '%s'", b64(data)); + module:log("debug", "Invalid SOCKS5 greeting received: '%s'", b64(data)); else -- connection request --local head = string.char( 0x05, 0x01, 0x00, 0x03, 40 ); -- ( VER=5=SOCKS5, CMD=1=CONNECT, RSV=0=RESERVED, ATYP=3=DOMAIMNAME, SHA-1 size ) if #data == 47 and data:sub(1,5) == "\5\1\0\3\40" and data:sub(-2) == "\0\0" then @@ -66,12 +66,12 @@ function listener.onincoming(conn, data) else -- error, unexpected input conn:write("\5\1\0\3\0\0\0"); -- VER, REP, RSV, ATYP, BND.ADDR (sha), BND.PORT (2 Byte) conn:close(); - module:log("debug", "Invalid SOCKS5 negotiation recieved: '%s'", b64(data)); + module:log("debug", "Invalid SOCKS5 negotiation received: '%s'", b64(data)); end end end -function listener.ondisconnect(conn, err) +function listener.ondisconnect(conn) local session = sessions[conn]; if session then if transfers[session.sha] then @@ -79,7 +79,7 @@ function listener.ondisconnect(conn, err) if initiator == conn and target ~= nil then target:close(); elseif target == conn and initiator ~= nil then - initiator:close(); + initiator:close(); end transfers[session.sha] = nil; end @@ -109,7 +109,8 @@ function module.add_host(module) local origin, stanza = event.origin, event.stanza; -- check ACL - while proxy_acl and #proxy_acl > 0 do -- using 'while' instead of 'if' so we can break out of it + -- using 'while' instead of 'if' so we can break out of it + while proxy_acl and #proxy_acl > 0 do --luacheck: ignore 512 local jid = stanza.attr.from; local allow; for _, acl in ipairs(proxy_acl) do @@ -123,7 +124,7 @@ function module.add_host(module) local sid = stanza.tags[1].attr.sid; origin.send(st.reply(stanza):tag("query", {xmlns="http://jabber.org/protocol/bytestreams", sid=sid}) - :tag("streamhost", {jid=host, host=proxy_address, port=proxy_port})); + :tag("streamhost", {jid=host, host=proxy_address, port=("%d"):format(proxy_port)})); return true; end); diff --git a/plugins/mod_pubsub/mod_pubsub.lua b/plugins/mod_pubsub/mod_pubsub.lua index 8e7bfc53..82c787aa 100644 --- a/plugins/mod_pubsub/mod_pubsub.lua +++ b/plugins/mod_pubsub/mod_pubsub.lua @@ -2,6 +2,7 @@ local pubsub = require "util.pubsub"; local st = require "util.stanza"; local jid_bare = require "util.jid".bare; local usermanager = require "core.usermanager"; +local new_id = require "util.id".medium; local xmlns_pubsub = "http://jabber.org/protocol/pubsub"; local xmlns_pubsub_event = "http://jabber.org/protocol/pubsub#event"; @@ -12,32 +13,35 @@ local autocreate_on_subscribe = module:get_option_boolean("autocreate_on_subscri local pubsub_disco_name = module:get_option_string("name", "Prosody PubSub Service"); local expose_publisher = module:get_option_boolean("expose_publisher", false) +local enable_persistence = module:get_option_boolean("experimental_pubsub_item_persistence", false); + local service; local lib_pubsub = module:require "pubsub"; -local handlers = lib_pubsub.handlers; -local pubsub_error_reply = lib_pubsub.pubsub_error_reply; module:depends("disco"); module:add_identity("pubsub", "service", pubsub_disco_name); module:add_feature("http://jabber.org/protocol/pubsub"); function handle_pubsub_iq(event) - local origin, stanza = event.origin, event.stanza; - local pubsub = stanza.tags[1]; - local action = pubsub.tags[1]; - if not action then - origin.send(st.error_reply(stanza, "cancel", "bad-request")); - return true; - end - local handler = handlers[stanza.attr.type.."_"..action.name]; - if handler then - handler(origin, stanza, action, service); - return true; - end + return lib_pubsub.handle_pubsub_iq(event, service); +end + +local node_store = module:open_store(module.name.."_nodes"); + +local function create_simple_itemstore(node_config, node_name) + local archive = module:open_store("pubsub_"..node_name, "archive"); + return lib_pubsub.archive_itemstore(archive, node_config, nil, node_name); +end + +if enable_persistence then + module:log("warn", "Item persistence is an experimental feature. Note that ownership information is lost on restart.") +else + create_simple_itemstore = nil; end -function simple_broadcast(kind, node, jids, item, actor) + +function simple_broadcast(kind, node, jids, item, actor, node_obj) if item then item = st.clone(item); item.attr.xmlns = nil; -- Clear the pubsub namespace @@ -45,52 +49,55 @@ function simple_broadcast(kind, node, jids, item, actor) item.attr.publisher = actor end end - local message = st.message({ from = module.host, type = "headline" }) + + local id = new_id(); + local msg_type = node_obj and node_obj.config.message_type or "headline"; + local message = st.message({ from = module.host, type = msg_type, id = id }) :tag("event", { xmlns = xmlns_pubsub_event }) :tag(kind, { node = node }) - :add_child(item); - for jid in pairs(jids) do - module:log("debug", "Sending notification to %s", jid); - message.attr.to = jid; - module:send(message); - end -end -module:hook("iq/host/"..xmlns_pubsub..":pubsub", handle_pubsub_iq); -module:hook("iq/host/"..xmlns_pubsub_owner..":pubsub", handle_pubsub_iq); + if item then + message:add_child(item); + end -local feature_map = { - create = { "create-nodes", "instant-nodes", "item-ids" }; - retract = { "delete-items", "retract-items" }; - purge = { "purge-nodes" }; - publish = { "publish", autocreate_on_publish and "auto-create" }; - delete = { "delete-nodes" }; - get_items = { "retrieve-items" }; - add_subscription = { "subscribe" }; - get_subscriptions = { "retrieve-subscriptions" }; - set_configure = { "config-node" }; - get_default = { "retrieve-default" }; -}; - -local function add_disco_features_from_service(service) - for method, features in pairs(feature_map) do - if service[method] then - for _, feature in ipairs(features) do - if feature then - module:add_feature(xmlns_pubsub.."#"..feature); + -- Compose a sensible textual representation of at least Atom payloads + if node_obj and item and node_obj.config.include_body and item.tags[1] then + local payload = item.tags[1]; + if payload.attr.xmlns == "http://www.w3.org/2005/Atom" then + message:reset(); + local title = payload:get_child_text("title"); + local summary = payload:get_child_text("summary"); + if not summary and title then + local author = payload:find("author/name#"); + summary = title; + if author then + summary = author .. " posted " .. summary; end end + if summary then + message:body(summary); + end end end - for affiliation in pairs(service.config.capabilities) do - if affiliation ~= "none" and affiliation ~= "owner" then - module:add_feature(xmlns_pubsub.."#"..affiliation.."-affiliation"); - end + + module:broadcast(jids, message, pairs); +end + +function is_item_stanza(item) + return st.is_stanza(item) and item.attr.xmlns == xmlns_pubsub and item.name == "item"; +end + +module:hook("iq/host/"..xmlns_pubsub..":pubsub", handle_pubsub_iq); +module:hook("iq/host/"..xmlns_pubsub_owner..":pubsub", handle_pubsub_iq); + +local function add_disco_features_from_service(service) --luacheck: ignore 431/service + for feature in lib_pubsub.get_feature_set(service) do + module:add_feature(xmlns_pubsub.."#"..feature); end end module:hook("host-disco-info-node", function (event) - local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; + local stanza, reply, node = event.stanza, event.reply, event.node; local ok, ret = service:get_nodes(stanza.attr.from); if not ok or not ret[node] then return; @@ -100,7 +107,7 @@ module:hook("host-disco-info-node", function (event) end); module:hook("host-disco-items-node", function (event) - local stanza, origin, reply, node = event.stanza, event.origin, event.reply, event.node; + local stanza, reply, node = event.stanza, event.reply, event.node; local ok, ret = service:get_items(node, stanza.attr.from); if not ok then return; @@ -114,8 +121,8 @@ end); module:hook("host-disco-items", function (event) - local stanza, origin, reply = event.stanza, event.origin, event.reply; - local ok, ret = service:get_nodes(event.stanza.attr.from); + local stanza, reply = event.stanza, event.reply; + local ok, ret = service:get_nodes(stanza.attr.from); if not ok then return; end @@ -225,7 +232,10 @@ function module.load() autocreate_on_publish = autocreate_on_publish; autocreate_on_subscribe = autocreate_on_subscribe; + nodestore = node_store; + itemstore = create_simple_itemstore; broadcaster = simple_broadcast; + itemcheck = is_item_stanza; get_affiliation = get_affiliation; normalize_jid = jid_bare; diff --git a/plugins/mod_pubsub/pubsub.lib.lua b/plugins/mod_pubsub/pubsub.lib.lua index 1497c21c..2ec6e8de 100644 --- a/plugins/mod_pubsub/pubsub.lib.lua +++ b/plugins/mod_pubsub/pubsub.lib.lua @@ -1,4 +1,10 @@ +local t_unpack = table.unpack or unpack; -- luacheck: ignore 113 +local time_now = os.time; + +local jid_prep = require "util.jid".prep; +local set = require "util.set"; local st = require "util.stanza"; +local it = require "util.iterators"; local uuid_generate = require "util.uuid".generate; local dataform = require"util.dataforms".new; @@ -23,7 +29,7 @@ local pubsub_errors = { }; local function pubsub_error_reply(stanza, error) local e = pubsub_errors[error]; - local reply = st.error_reply(stanza, unpack(e, 1, 3)); + local reply = st.error_reply(stanza, t_unpack(e, 1, 3)); if e[4] then reply:tag(e[4], { xmlns = xmlns_pubsub_errors }):up(); end @@ -31,7 +37,7 @@ local function pubsub_error_reply(stanza, error) end _M.pubsub_error_reply = pubsub_error_reply; -local node_config_form = require"util.dataforms".new { +local node_config_form = dataform { { type = "hidden"; name = "FORM_TYPE"; @@ -42,18 +48,113 @@ local node_config_form = require"util.dataforms".new { name = "pubsub#max_items"; label = "Max # of items to persist"; }; + { + type = "boolean"; + name = "pubsub#persist_items"; + label = "Persist items to storage"; + }; + { + type = "boolean"; + name = "pubsub#include_body"; + label = "Receive message body in addition to payload?"; + }; + { + type = "list-single"; + name = "pubsub#notification_type"; + label = "Specify the delivery style for notifications"; + options = { + { label = "Messages of type normal", value = "normal" }, + { label = "Messages of type headline", value = "headline", default = true }, + }; + }; +}; + +local options_form = dataform { + { + type = "hidden"; + name = "FORM_TYPE"; + value = "http://jabber.org/protocol/pubsub#subscribe_options"; + } + -- No options yet. File a feature request ;) +}; + +local service_method_feature_map = { + add_subscription = { "subscribe" }; + create = { "create-nodes", "instant-nodes", "item-ids", "create-and-configure" }; + delete = { "delete-nodes" }; + get_items = { "retrieve-items" }; + get_subscriptions = { "retrieve-subscriptions" }; + node_defaults = { "retrieve-default" }; + publish = { "publish" }; + purge = { "purge-nodes" }; + retract = { "delete-items", "retract-items" }; + set_node_config = { "config-node" }; + set_affiliation = { "modify-affiliations" }; }; +local service_config_feature_map = { + autocreate_on_publish = { "auto-create" }; +}; + +function _M.get_feature_set(service) + local supported_features = set.new(); + + for method, features in pairs(service_method_feature_map) do + if service[method] then + for _, feature in ipairs(features) do + if feature then + supported_features:add(feature); + end + end + end + end + + for option, features in pairs(service_config_feature_map) do + if service.config[option] then + for _, feature in ipairs(features) do + if feature then + supported_features:add(feature); + end + end + end + end + + for affiliation in pairs(service.config.capabilities) do + if affiliation ~= "none" and affiliation ~= "owner" then + supported_features:add(affiliation.."-affiliation"); + end + end + + return supported_features; +end + +function _M.handle_pubsub_iq(event, service) + local origin, stanza = event.origin, event.stanza; + local pubsub_tag = stanza.tags[1]; + local action = pubsub_tag.tags[1]; + if not action then + return origin.send(st.error_reply(stanza, "cancel", "bad-request")); + end + local prefix = ""; + if pubsub_tag.attr.xmlns == xmlns_pubsub_owner then + prefix = "owner_"; + end + local handler = handlers[prefix..stanza.attr.type.."_"..action.name]; + if handler then + handler(origin, stanza, action, service); + return true; + end +end function handlers.get_items(origin, stanza, items, service) local node = items.attr.node; local item = items:get_child("item"); - local id = item and item.attr.id; + local item_id = item and item.attr.id; if not node then origin.send(pubsub_error_reply(stanza, "nodeid-required")); return true; end - local ok, results = service:get_items(node, stanza.attr.from, id); + local ok, results = service:get_items(node, stanza.attr.from, item_id); if not ok then origin.send(pubsub_error_reply(stanza, results)); return true; @@ -95,8 +196,28 @@ end function handlers.set_create(origin, stanza, create, service) local node = create.attr.node; local ok, ret, reply; + local config; + local configure = stanza.tags[1]:get_child("configure"); + if configure then + local config_form = configure:get_child("x", "jabber:x:data"); + if not config_form then + origin.send(st.error_reply(stanza, "modify", "bad-request", "Missing dataform")); + return true; + end + local form_data, err = node_config_form:data(config_form); + if not form_data then + origin.send(st.error_reply(stanza, "modify", "bad-request", err)); + return true; + end + config = { + ["max_items"] = tonumber(form_data["pubsub#max_items"]); + ["persist_items"] = form_data["pubsub#persist_items"]; + ["notification_type"] = form_data["pubsub#notification_type"]; + ["include_body"] = form_data["pubsub#include_body"]; + }; + end if node then - ok, ret = service:create(node, stanza.attr.from); + ok, ret = service:create(node, stanza.attr.from, config); if ok then reply = st.reply(stanza); else @@ -105,7 +226,7 @@ function handlers.set_create(origin, stanza, create, service) else repeat node = uuid_generate(); - ok, ret = service:create(node, stanza.attr.from); + ok, ret = service:create(node, stanza.attr.from, config); until ok or ret ~= "conflict"; if ok then reply = st.reply(stanza) @@ -119,10 +240,10 @@ function handlers.set_create(origin, stanza, create, service) return true; end -function handlers.set_delete(origin, stanza, delete, service) +function handlers.owner_set_delete(origin, stanza, delete, service) local node = delete.attr.node; - local reply, notifier; + local reply; if not node then origin.send(pubsub_error_reply(stanza, "nodeid-required")); return true; @@ -139,17 +260,15 @@ end function handlers.set_subscribe(origin, stanza, subscribe, service) local node, jid = subscribe.attr.node, subscribe.attr.jid; + jid = jid_prep(jid); if not (node and jid) then origin.send(pubsub_error_reply(stanza, jid and "nodeid-required" or "invalid-jid")); return true; end - --[[ local options_tag, options = stanza.tags[1]:get_child("options"), nil; if options_tag then options = options_form:data(options_tag.tags[1]); end - --]] - local options_tag, options; -- FIXME local ok, ret = service:add_subscription(node, stanza.attr.from, jid, options); local reply; if ok then @@ -171,6 +290,7 @@ end function handlers.set_unsubscribe(origin, stanza, unsubscribe, service) local node, jid = unsubscribe.attr.node, unsubscribe.attr.jid; + jid = jid_prep(jid); if not (node and jid) then origin.send(pubsub_error_reply(stanza, jid and "nodeid-required" or "invalid-jid")); return true; @@ -203,6 +323,9 @@ function handlers.set_publish(origin, stanza, publish, service) local ok, ret = service:publish(node, stanza.attr.from, id, item); local reply; if ok then + if type(ok) == "string" then + id = ok; + end reply = st.reply(stanza) :tag("pubsub", { xmlns = xmlns_pubsub }) :tag("publish", { node = node }) @@ -237,7 +360,7 @@ function handlers.set_retract(origin, stanza, retract, service) return true; end -function handlers.set_purge(origin, stanza, purge, service) +function handlers.owner_set_purge(origin, stanza, purge, service) local node, notify = purge.attr.node, purge.attr.notify; notify = (notify == "1") or (notify == "true"); local reply; @@ -255,7 +378,7 @@ function handlers.set_purge(origin, stanza, purge, service) return true; end -function handlers.get_configure(origin, stanza, config, service) +function handlers.owner_get_configure(origin, stanza, config, service) local node = config.attr.node; if not node then origin.send(pubsub_error_reply(stanza, "nodeid-required")); @@ -273,15 +396,22 @@ function handlers.get_configure(origin, stanza, config, service) return true; end + local node_config = node_obj.config; + local pubsub_form_data = { + ["pubsub#max_items"] = tostring(node_config["max_items"]); + ["pubsub#persist_items"] = node_config["persist_items"]; + ["pubsub#notification_type"] = node_config["notification_type"]; + ["pubsub#include_body"] = node_config["include_body"]; + } local reply = st.reply(stanza) :tag("pubsub", { xmlns = xmlns_pubsub_owner }) :tag("configure", { node = node }) - :add_child(node_config_form:form(node_obj.config)); + :add_child(node_config_form:form(pubsub_form_data)); origin.send(reply); return true; end -function handlers.set_configure(origin, stanza, config, service) +function handlers.owner_set_configure(origin, stanza, config, service) local node = config.attr.node; if not node then origin.send(pubsub_error_reply(stanza, "nodeid-required")); @@ -291,11 +421,20 @@ function handlers.set_configure(origin, stanza, config, service) origin.send(pubsub_error_reply(stanza, "forbidden")); return true; end - local new_config, err = node_config_form:data(config.tags[1]); - if not new_config then + local config_form = config:get_child("x", "jabber:x:data"); + if not config_form then + origin.send(st.error_reply(stanza, "modify", "bad-request", "Missing dataform")); + return true; + end + local form_data, err = node_config_form:data(config_form); + if not form_data then origin.send(st.error_reply(stanza, "modify", "bad-request", err)); return true; end + local new_config = { + ["max_items"] = tonumber(form_data["pubsub#max_items"]); + ["persist_items"] = form_data["pubsub#persist_items"]; + }; local ok, err = service:set_node_config(node, stanza.attr.from, new_config); if not ok then origin.send(pubsub_error_reply(stanza, err)); @@ -305,13 +444,164 @@ function handlers.set_configure(origin, stanza, config, service) return true; end -function handlers.get_default(origin, stanza, default, service) +function handlers.owner_get_default(origin, stanza, default, service) -- luacheck: ignore 212/default + local pubsub_form_data = { + ["pubsub#max_items"] = tostring(service.node_defaults["max_items"]); + ["pubsub#persist_items"] = service.node_defaults["persist_items"] + } local reply = st.reply(stanza) :tag("pubsub", { xmlns = xmlns_pubsub_owner }) :tag("default") - :add_child(node_config_form:form(service.node_defaults)); + :add_child(node_config_form:form(pubsub_form_data)); + origin.send(reply); + return true; +end + +function handlers.owner_get_affiliations(origin, stanza, affiliations, service) + local node = affiliations.attr.node; + if not node then + origin.send(pubsub_error_reply(stanza, "nodeid-required")); + return true; + end + if not service:may(node, stanza.attr.from, "set_affiliation") then + origin.send(pubsub_error_reply(stanza, "forbidden")); + return true; + end + + local node_obj = service.nodes[node]; + if not node_obj then + origin.send(pubsub_error_reply(stanza, "item-not-found")); + return true; + end + + local reply = st.reply(stanza) + :tag("pubsub", { xmlns = xmlns_pubsub_owner }) + :tag("affiliations", { node = node }); + + for jid, affiliation in pairs(node_obj.affiliations) do + reply:tag("affiliation", { jid = jid, affiliation = affiliation }):up(); + end + origin.send(reply); return true; end +function handlers.owner_set_affiliations(origin, stanza, affiliations, service) + local node = affiliations.attr.node; + if not node then + origin.send(pubsub_error_reply(stanza, "nodeid-required")); + return true; + end + if not service:may(node, stanza.attr.from, "set_affiliation") then + origin.send(pubsub_error_reply(stanza, "forbidden")); + return true; + end + + local node_obj = service.nodes[node]; + if not node_obj then + origin.send(pubsub_error_reply(stanza, "item-not-found")); + return true; + end + + for affiliation_tag in affiliations:childtags("affiliation") do + local jid = affiliation_tag.attr.jid; + local affiliation = affiliation_tag.attr.affiliation; + + jid = jid_prep(jid); + if affiliation == "none" then affiliation = nil; end + + local ok, err = service:set_affiliation(node, stanza.attr.from, jid, affiliation); + if not ok then + -- FIXME Incomplete error handling, + -- see XEP 60 8.9.2.4 Multiple Simultaneous Modifications + origin.send(pubsub_error_reply(stanza, err)); + return true; + end + end + + local reply = st.reply(stanza); + origin.send(reply); + return true; +end + +local function create_encapsulating_item(id, payload) + local item = st.stanza("item", { id = id, xmlns = xmlns_pubsub }); + item:add_child(payload); + return item; +end + +local function archive_itemstore(archive, config, user, node) + module:log("debug", "Creation of itemstore for node %s with config %s", node, config); + local get_set = {}; + function get_set:items() -- luacheck: ignore 212/self + local data, err = archive:find(user, { + limit = tonumber(config["max_items"]); + reverse = true; + }); + if not data then + module:log("error", "Unable to get items: %s", err); + return true; + end + module:log("debug", "Listed items %s", data); + return it.reverse(function() + local id, payload, when, publisher = data(); + if id == nil then + return; + end + local item = create_encapsulating_item(id, payload, publisher); + return id, item; + end); + end + function get_set:get(key) -- luacheck: ignore 212/self + local data, err = archive:find(user, { + key = key; + -- Get the last item with that key, if the archive doesn't deduplicate + reverse = true, + limit = 1; + }); + if not data then + module:log("error", "Unable to get item: %s", err); + return nil, err; + end + local id, payload, when, publisher = data(); + module:log("debug", "Get item %s (published at %s by %s)", id, when, publisher); + if id == nil then + return nil; + end + return create_encapsulating_item(id, payload, publisher); + end + function get_set:set(key, value) -- luacheck: ignore 212/self + local data, err; + if value ~= nil then + local publisher = value.attr.publisher; + local payload = value.tags[1]; + data, err = archive:append(user, key, payload, time_now(), publisher); + else + data, err = archive:delete(user, { key = key; }); + end + if not data then + module:log("error", "Unable to set item: %s", err); + return nil, err; + end + return data; + end + function get_set:clear() -- luacheck: ignore 212/self + return archive:delete(user); + end + function get_set:resize(size) -- luacheck: ignore 212/self + return archive:delete(user, { + truncate = size; + }); + end + function get_set:tail() + -- This should conveniently return the last item + local item = self:get(nil); + if item then + return item.attr.id, item; + end + end + return setmetatable(get_set, archive); +end +_M.archive_itemstore = archive_itemstore; + return _M; diff --git a/plugins/mod_register.lua b/plugins/mod_register.lua index b39ce090..49ff8a38 100644 --- a/plugins/mod_register.lua +++ b/plugins/mod_register.lua @@ -7,288 +7,11 @@ -- -local st = require "util.stanza"; -local dataform_new = require "util.dataforms".new; -local usermanager_user_exists = require "core.usermanager".user_exists; -local usermanager_create_user = require "core.usermanager".create_user; -local usermanager_set_password = require "core.usermanager".set_password; -local usermanager_delete_user = require "core.usermanager".delete_user; -local nodeprep = require "util.encodings".stringprep.nodeprep; -local jid_bare = require "util.jid".bare; -local create_throttle = require "util.throttle".create; -local new_cache = require "util.cache".new; - -local compat = module:get_option_boolean("registration_compat", true); local allow_registration = module:get_option_boolean("allow_registration", false); -local additional_fields = module:get_option("additional_registration_fields", {}); -local require_encryption = module:get_option("c2s_require_encryption") or module:get_option("require_encryption"); - -local account_details = module:open_store("account_details"); - -local field_map = { - username = { name = "username", type = "text-single", label = "Username", required = true }; - password = { name = "password", type = "text-private", label = "Password", required = true }; - nick = { name = "nick", type = "text-single", label = "Nickname" }; - name = { name = "name", type = "text-single", label = "Full Name" }; - first = { name = "first", type = "text-single", label = "Given Name" }; - last = { name = "last", type = "text-single", label = "Family Name" }; - email = { name = "email", type = "text-single", label = "Email" }; - address = { name = "address", type = "text-single", label = "Street" }; - city = { name = "city", type = "text-single", label = "City" }; - state = { name = "state", type = "text-single", label = "State" }; - zip = { name = "zip", type = "text-single", label = "Postal code" }; - phone = { name = "phone", type = "text-single", label = "Telephone number" }; - url = { name = "url", type = "text-single", label = "Webpage" }; - date = { name = "date", type = "text-single", label = "Birth date" }; -}; - -local title = module:get_option_string("registration_title", - "Creating a new account"); -local instructions = module:get_option_string("registration_instructions", - "Choose a username and password for use with this service."); - -local registration_form = dataform_new{ - title = title; - instructions = instructions; - - field_map.username; - field_map.password; -}; - -local registration_query = st.stanza("query", {xmlns = "jabber:iq:register"}) - :tag("instructions"):text(instructions):up() - :tag("username"):up() - :tag("password"):up(); - -for _, field in ipairs(additional_fields) do - if type(field) == "table" then - registration_form[#registration_form + 1] = field; - elseif field_map[field] or field_map[field:sub(1, -2)] then - if field:match("%+$") then - field = field:sub(1, -2); - field_map[field].required = true; - end - registration_form[#registration_form + 1] = field_map[field]; - registration_query:tag(field):up(); - else - module:log("error", "Unknown field %q", field); - end +if allow_registration then + module:depends("register_ibr"); + module:depends("register_limits"); end -registration_query:add_child(registration_form:form()); - -module:add_feature("jabber:iq:register"); - -local register_stream_feature = st.stanza("register", {xmlns="http://jabber.org/features/iq-register"}):up(); -module:hook("stream-features", function(event) - local session, features = event.origin, event.features; - - -- Advertise registration to unauthorized clients only. - if not(allow_registration) or session.type ~= "c2s_unauthed" or (require_encryption and not session.secure) then - return - end - - features:add_child(register_stream_feature); -end); - --- Password change and account deletion handler -local function handle_registration_stanza(event) - local session, stanza = event.origin, event.stanza; - local log = session.log or module._log; - - local query = stanza.tags[1]; - if stanza.attr.type == "get" then - local reply = st.reply(stanza); - reply:tag("query", {xmlns = "jabber:iq:register"}) - :tag("registered"):up() - :tag("username"):text(session.username):up() - :tag("password"):up(); - session.send(reply); - else -- stanza.attr.type == "set" - if query.tags[1] and query.tags[1].name == "remove" then - local username, host = session.username, session.host; - - -- This one weird trick sends a reply to this stanza before the user is deleted - local old_session_close = session.close; - session.close = function(self, ...) - self.send(st.reply(stanza)); - return old_session_close(self, ...); - end - - local ok, err = usermanager_delete_user(username, host); - - if not ok then - log("debug", "Removing user account %s@%s failed: %s", username, host, err); - session.close = old_session_close; - session.send(st.error_reply(stanza, "cancel", "service-unavailable", err)); - return true; - end - - log("info", "User removed their account: %s@%s", username, host); - module:fire_event("user-deregistered", { username = username, host = host, source = "mod_register", session = session }); - else - local username = nodeprep(query:get_child_text("username")); - local password = query:get_child_text("password"); - if username and password then - if username == session.username then - if usermanager_set_password(username, password, session.host, session.resource) then - session.send(st.reply(stanza)); - else - -- TODO unable to write file, file may be locked, etc, what's the correct error? - session.send(st.error_reply(stanza, "wait", "internal-server-error")); - end - else - session.send(st.error_reply(stanza, "modify", "bad-request")); - end - else - session.send(st.error_reply(stanza, "modify", "bad-request")); - end - end - end - return true; -end - -module:hook("iq/self/jabber:iq:register:query", handle_registration_stanza); -if compat then - module:hook("iq/host/jabber:iq:register:query", function (event) - local session, stanza = event.origin, event.stanza; - if session.type == "c2s" and jid_bare(stanza.attr.to) == session.host then - return handle_registration_stanza(event); - end - end); -end - -local function parse_response(query) - local form = query:get_child("x", "jabber:x:data"); - if form then - return registration_form:data(form); - else - local data = {}; - local errors = {}; - for _, field in ipairs(registration_form) do - local name, required = field.name, field.required; - if field_map[name] then - data[name] = query:get_child_text(name); - if (not data[name] or #data[name] == 0) and required then - errors[name] = "Required value missing"; - end - end - end - if next(errors) then - return data, errors; - end - return data; - end -end - -local min_seconds_between_registrations = module:get_option_number("min_seconds_between_registrations"); -local whitelist_only = module:get_option_boolean("whitelist_registration_only"); -local whitelisted_ips = module:get_option_set("registration_whitelist", { "127.0.0.1", "::1" })._items; -local blacklisted_ips = module:get_option_set("registration_blacklist", {})._items; - -local throttle_max = module:get_option_number("registration_throttle_max", min_seconds_between_registrations and 1); -local throttle_period = module:get_option_number("registration_throttle_period", min_seconds_between_registrations); -local throttle_cache_size = module:get_option_number("registration_throttle_cache_size", 100); -local blacklist_overflow = module:get_option_boolean("blacklist_on_registration_throttle_overload", false); - -local throttle_cache = new_cache(throttle_cache_size, blacklist_overflow and function (ip, throttle) - if not throttle:peek() then - module:log("info", "Adding ip %s to registration blacklist", ip); - blacklisted_ips[ip] = true; - end -end or nil); - -local function check_throttle(ip) - if not throttle_max then return true end - local throttle = throttle_cache:get(ip); - if not throttle then - throttle = create_throttle(throttle_max, throttle_period); - end - throttle_cache:set(ip, throttle); - return throttle:poll(1); -end - --- In-band registration -module:hook("stanza/iq/jabber:iq:register:query", function(event) - local session, stanza = event.origin, event.stanza; - local log = session.log or module._log; - if not(allow_registration) or session.type ~= "c2s_unauthed" then - log("debug", "Attempted registration when disabled or already authenticated"); - session.send(st.error_reply(stanza, "cancel", "service-unavailable")); - elseif require_encryption and not session.secure then - session.send(st.error_reply(stanza, "modify", "policy-violation", "Encryption is required")); - else - local query = stanza.tags[1]; - if stanza.attr.type == "get" then - local reply = st.reply(stanza); - reply:add_child(registration_query); - session.send(reply); - elseif stanza.attr.type == "set" then - if query.tags[1] and query.tags[1].name == "remove" then - session.send(st.error_reply(stanza, "auth", "registration-required")); - else - local data, errors = parse_response(query); - if errors then - log("debug", "Error parsing registration form:"); - for field, err in pairs(errors) do - log("debug", "Field %q: %s", field, err); - end - session.send(st.error_reply(stanza, "modify", "not-acceptable")); - else - -- Check that the user is not blacklisted or registering too often - if not session.ip then - log("debug", "User's IP not known; can't apply blacklist/whitelist"); - elseif blacklisted_ips[session.ip] or (whitelist_only and not whitelisted_ips[session.ip]) then - session.send(st.error_reply(stanza, "cancel", "not-acceptable", "You are not allowed to register an account.")); - return true; - elseif throttle_max and not whitelisted_ips[session.ip] then - if not check_throttle(session.ip) then - log("debug", "Registrations over limit for ip %s", session.ip or "?"); - session.send(st.error_reply(stanza, "wait", "not-acceptable")); - return true; - end - end - local username, password = nodeprep(data.username), data.password; - data.username, data.password = nil, nil; - local host = module.host; - if not username or username == "" then - log("debug", "The requested username is invalid."); - session.send(st.error_reply(stanza, "modify", "not-acceptable", "The requested username is invalid.")); - return true; - end - local user = { username = username , host = host, additional = data, allowed = true } - module:fire_event("user-registering", user); - if not user.allowed then - log("debug", "Registration disallowed by module"); - session.send(st.error_reply(stanza, "modify", "not-acceptable", "The requested username is forbidden.")); - elseif usermanager_user_exists(username, host) then - log("debug", "Attempt to register with existing username"); - session.send(st.error_reply(stanza, "cancel", "conflict", "The requested username already exists.")); - else - -- TODO unable to write file, file may be locked, etc, what's the correct error? - local error_reply = st.error_reply(stanza, "wait", "internal-server-error", "Failed to write data to disk."); - if usermanager_create_user(username, password, host) then - data.registered = os.time(); - if not account_details:set(username, data) then - log("debug", "Could not store extra details"); - usermanager_delete_user(username, host); - session.send(error_reply); - return true; - end - session.send(st.reply(stanza)); -- user created! - log("info", "User account created: %s@%s", username, host); - module:fire_event("user-registered", { - username = username, host = host, source = "mod_register", - session = session }); - else - log("debug", "Could not create user"); - session.send(error_reply); - end - end - end - end - end - end - return true; -end); +module:depends("user_account_management"); diff --git a/plugins/mod_register_ibr.lua b/plugins/mod_register_ibr.lua new file mode 100644 index 00000000..dc7168b4 --- /dev/null +++ b/plugins/mod_register_ibr.lua @@ -0,0 +1,195 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + + +local st = require "util.stanza"; +local dataform_new = require "util.dataforms".new; +local usermanager_user_exists = require "core.usermanager".user_exists; +local usermanager_create_user = require "core.usermanager".create_user; +local usermanager_delete_user = require "core.usermanager".delete_user; +local nodeprep = require "util.encodings".stringprep.nodeprep; + +local additional_fields = module:get_option("additional_registration_fields", {}); +local require_encryption = module:get_option_boolean("c2s_require_encryption", + module:get_option_boolean("require_encryption", false)); + +local account_details = module:open_store("account_details"); + +local field_map = { + username = { name = "username", type = "text-single", label = "Username", required = true }; + password = { name = "password", type = "text-private", label = "Password", required = true }; + nick = { name = "nick", type = "text-single", label = "Nickname" }; + name = { name = "name", type = "text-single", label = "Full Name" }; + first = { name = "first", type = "text-single", label = "Given Name" }; + last = { name = "last", type = "text-single", label = "Family Name" }; + email = { name = "email", type = "text-single", label = "Email" }; + address = { name = "address", type = "text-single", label = "Street" }; + city = { name = "city", type = "text-single", label = "City" }; + state = { name = "state", type = "text-single", label = "State" }; + zip = { name = "zip", type = "text-single", label = "Postal code" }; + phone = { name = "phone", type = "text-single", label = "Telephone number" }; + url = { name = "url", type = "text-single", label = "Webpage" }; + date = { name = "date", type = "text-single", label = "Birth date" }; +}; + +local title = module:get_option_string("registration_title", + "Creating a new account"); +local instructions = module:get_option_string("registration_instructions", + "Choose a username and password for use with this service."); + +local registration_form = dataform_new{ + title = title; + instructions = instructions; + + field_map.username; + field_map.password; +}; + +local registration_query = st.stanza("query", {xmlns = "jabber:iq:register"}) + :tag("instructions"):text(instructions):up() + :tag("username"):up() + :tag("password"):up(); + +for _, field in ipairs(additional_fields) do + if type(field) == "table" then + registration_form[#registration_form + 1] = field; + elseif field_map[field] or field_map[field:sub(1, -2)] then + if field:match("%+$") then + field = field:sub(1, -2); + field_map[field].required = true; + end + + registration_form[#registration_form + 1] = field_map[field]; + registration_query:tag(field):up(); + else + module:log("error", "Unknown field %q", field); + end +end +registration_query:add_child(registration_form:form()); + +local register_stream_feature = st.stanza("register", {xmlns="http://jabber.org/features/iq-register"}):up(); +module:hook("stream-features", function(event) + local session, features = event.origin, event.features; + + -- Advertise registration to unauthorized clients only. + if session.type ~= "c2s_unauthed" or (require_encryption and not session.secure) then + return + end + + features:add_child(register_stream_feature); +end); + +local function parse_response(query) + local form = query:get_child("x", "jabber:x:data"); + if form then + return registration_form:data(form); + else + local data = {}; + local errors = {}; + for _, field in ipairs(registration_form) do + local name, required = field.name, field.required; + if field_map[name] then + data[name] = query:get_child_text(name); + if (not data[name] or #data[name] == 0) and required then + errors[name] = "Required value missing"; + end + end + end + if next(errors) then + return data, errors; + end + return data; + end +end + +-- In-band registration +module:hook("stanza/iq/jabber:iq:register:query", function(event) + local session, stanza = event.origin, event.stanza; + local log = session.log or module._log; + + if session.type ~= "c2s_unauthed" then + log("debug", "Attempted registration when disabled or already authenticated"); + session.send(st.error_reply(stanza, "cancel", "service-unavailable")); + return true; + end + + if require_encryption and not session.secure then + session.send(st.error_reply(stanza, "modify", "policy-violation", "Encryption is required")); + return true; + end + + local query = stanza.tags[1]; + if stanza.attr.type == "get" then + local reply = st.reply(stanza); + reply:add_child(registration_query); + session.send(reply); + return true; + end + + -- stanza.attr.type == "set" + if query.tags[1] and query.tags[1].name == "remove" then + session.send(st.error_reply(stanza, "auth", "registration-required")); + return true; + end + + local data, errors = parse_response(query); + if errors then + log("debug", "Error parsing registration form:"); + local textual_errors = {}; + for field, err in pairs(errors) do + log("debug", "Field %q: %s", field, err); + table.insert(textual_errors, ("%s: %s"):format(field:gsub("^%a", string.upper), err)); + end + session.send(st.error_reply(stanza, "modify", "not-acceptable", table.concat(textual_errors, "\n"))); + return true; + end + + local username, password = nodeprep(data.username), data.password; + data.username, data.password = nil, nil; + local host = module.host; + if not username or username == "" then + log("debug", "The requested username is invalid."); + session.send(st.error_reply(stanza, "modify", "not-acceptable", "The requested username is invalid.")); + return true; + end + + local user = { username = username , host = host, additional = data, ip = session.ip, session = session, allowed = true } + module:fire_event("user-registering", user); + if not user.allowed then + log("debug", "Registration disallowed by module: %s", user.reason or "no reason given"); + session.send(st.error_reply(stanza, "modify", "not-acceptable", user.reason)); + return true; + end + + if usermanager_user_exists(username, host) then + log("debug", "Attempt to register with existing username"); + session.send(st.error_reply(stanza, "cancel", "conflict", "The requested username already exists.")); + return true; + end + + -- TODO unable to write file, file may be locked, etc, what's the correct error? + local error_reply = st.error_reply(stanza, "wait", "internal-server-error", "Failed to write data to disk."); + if usermanager_create_user(username, password, host) then + data.registered = os.time(); + if not account_details:set(username, data) then + log("debug", "Could not store extra details"); + usermanager_delete_user(username, host); + session.send(error_reply); + return true; + end + session.send(st.reply(stanza)); -- user created! + log("info", "User account created: %s@%s", username, host); + module:fire_event("user-registered", { + username = username, host = host, source = "mod_register", + session = session }); + else + log("debug", "Could not create user"); + session.send(error_reply); + end + return true; +end); diff --git a/plugins/mod_register_limits.lua b/plugins/mod_register_limits.lua new file mode 100644 index 00000000..736282a5 --- /dev/null +++ b/plugins/mod_register_limits.lua @@ -0,0 +1,78 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + + +local create_throttle = require "util.throttle".create; +local new_cache = require "util.cache".new; +local ip_util = require "util.ip"; +local new_ip = ip_util.new_ip; +local match_ip = ip_util.match; +local parse_cidr = ip_util.parse_cidr; + +local min_seconds_between_registrations = module:get_option_number("min_seconds_between_registrations"); +local whitelist_only = module:get_option_boolean("whitelist_registration_only"); +local whitelisted_ips = module:get_option_set("registration_whitelist", { "127.0.0.1", "::1" })._items; +local blacklisted_ips = module:get_option_set("registration_blacklist", {})._items; + +local throttle_max = module:get_option_number("registration_throttle_max", min_seconds_between_registrations and 1); +local throttle_period = module:get_option_number("registration_throttle_period", min_seconds_between_registrations); +local throttle_cache_size = module:get_option_number("registration_throttle_cache_size", 100); +local blacklist_overflow = module:get_option_boolean("blacklist_on_registration_throttle_overload", false); + +local throttle_cache = new_cache(throttle_cache_size, blacklist_overflow and function (ip, throttle) + if not throttle:peek() then + module:log("info", "Adding ip %s to registration blacklist", ip); + blacklisted_ips[ip] = true; + end +end or nil); + +local function check_throttle(ip) + if not throttle_max then return true end + local throttle = throttle_cache:get(ip); + if not throttle then + throttle = create_throttle(throttle_max, throttle_period); + end + throttle_cache:set(ip, throttle); + return throttle:poll(1); +end + +local function ip_in_set(set, ip) + if set[ip] then + return true; + end + ip = new_ip(ip); + for in_set in pairs(set) do + if match_ip(ip, parse_cidr(in_set)) then + return true; + end + end + return false; +end + +module:hook("user-registering", function (event) + local session = event.session; + local ip = event.ip or session and session.ip; + local log = session and session.log or module._log; + if not ip then + log("warn", "IP not known; can't apply blacklist/whitelist"); + elseif ip_in_set(blacklisted_ips, ip) then + log("debug", "Registration disallowed by blacklist"); + event.allowed = false; + event.reason = "Your IP address is blacklisted"; + elseif (whitelist_only and not ip_in_set(whitelisted_ips, ip)) then + log("debug", "Registration disallowed by whitelist"); + event.allowed = false; + event.reason = "Your IP address is not whitelisted"; + elseif throttle_max and not ip_in_set(whitelisted_ips, ip) then + if not check_throttle(ip) then + log("debug", "Registrations over limit for ip %s", ip or "?"); + event.allowed = false; + event.reason = "Too many registrations from this IP address recently"; + end + end +end); diff --git a/plugins/mod_roster.lua b/plugins/mod_roster.lua index 24c50678..39d59cbd 100644 --- a/plugins/mod_roster.lua +++ b/plugins/mod_roster.lua @@ -11,9 +11,8 @@ local st = require "util.stanza" local jid_split = require "util.jid".split; local jid_prep = require "util.jid".prep; -local t_concat = table.concat; local tonumber = tonumber; -local pairs, ipairs = pairs, ipairs; +local pairs = pairs; local rm_load_roster = require "core.rostermanager".load_roster; local rm_remove_from_roster = require "core.rostermanager".remove_from_roster; @@ -51,7 +50,7 @@ module:hook("iq/self/jabber:iq:roster:query", function(event) name = item.name, }); for group in pairs(item.groups) do - roster:tag("group"):text(group):up(); + roster:text_tag("group", group); end roster:up(); -- move out from item end @@ -96,12 +95,10 @@ module:hook("iq/self/jabber:iq:roster:query", function(event) else r_item.subscription = "none"; end - for _, child in ipairs(item) do - if child.name == "group" then - local text = t_concat(child); - if text and text ~= "" then - r_item.groups[text] = true; - end + for group in item:childtags("group") do + local text = group:get_text(); + if text then + r_item.groups[text] = true; end end local success, err_type, err_cond, err_msg = rm_add_to_roster(session, jid, r_item); diff --git a/plugins/mod_s2s/mod_s2s.lua b/plugins/mod_s2s/mod_s2s.lua index 1f38c13a..e5fb8042 100644 --- a/plugins/mod_s2s/mod_s2s.lua +++ b/plugins/mod_s2s/mod_s2s.lua @@ -14,7 +14,7 @@ local core_process_stanza = prosody.core_process_stanza; local tostring, type = tostring, type; local t_insert = table.insert; -local xpcall, traceback = xpcall, debug.traceback; +local traceback = debug.traceback; local add_task = require "util.timer".add_task; local st = require "util.stanza"; @@ -26,6 +26,7 @@ local s2s_new_outgoing = require "core.s2smanager".new_outgoing; local s2s_destroy_session = require "core.s2smanager".destroy_session; local uuid_gen = require "util.uuid".generate; local fire_global_event = prosody.events.fire_event; +local runner = require "util.async".runner; local s2sout = module:require("s2sout"); @@ -38,17 +39,25 @@ local secure_domains, insecure_domains = local require_encryption = module:get_option_boolean("s2s_require_encryption", false); local measure_connections = module:measure("connections", "amount"); +local measure_ipv6 = module:measure("ipv6", "amount"); local sessions = module:shared("sessions"); +local runner_callbacks = {}; + local log = module._log; module:hook("stats-update", function () local count = 0; - for _ in pairs(sessions) do + local ipv6 = 0; + for _, session in pairs(sessions) do count = count + 1; + if session.ip and session.ip:match(":") then + ipv6 = ipv6 + 1; + end end measure_connections(count); + measure_ipv6(ipv6); end); --- Handle stanzas to remote domains @@ -57,13 +66,16 @@ local bouncy_stanzas = { message = true, presence = true, iq = true }; local function bounce_sendq(session, reason) local sendq = session.sendq; if not sendq then return; end - session.log("info", "Sending error replies for "..#sendq.." queued stanzas because of failed outgoing connection to "..tostring(session.to_host)); + session.log("info", "Sending error replies for %d queued stanzas because of failed outgoing connection to %s", #sendq, session.to_host); local dummy = { type = "s2sin"; - send = function(s) + send = function () (session.log or log)("error", "Replying to to an s2s error reply, please report this! Traceback: %s", traceback()); end; dummy = true; + close = function () + (session.log or log)("error", "Attempting to close the dummy origin of s2s error replies, please report this! Traceback: %s", traceback()); + end; }; for i, data in ipairs(sendq) do local reply = data[2]; @@ -100,8 +112,15 @@ function route_to_existing_session(event) (host.log or log)("debug", "trying to send over unauthed s2sout to "..to_host); -- Queue stanza until we are able to send it - if host.sendq then t_insert(host.sendq, {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)}); - else host.sendq = { {tostring(stanza), stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza)} }; end + local queued_item = { + tostring(stanza), + stanza.attr.type ~= "error" and stanza.attr.type ~= "result" and st.reply(stanza); + }; + if host.sendq then + t_insert(host.sendq, queued_item); + else + host.sendq = { queued_item }; + end host.log("debug", "stanza [%s] queued ", stanza.name); return true; elseif host.type == "local" or host.type == "component" then @@ -114,7 +133,7 @@ function route_to_existing_session(event) -- FIXME if host.from_host ~= from_host then log("error", "WARNING! This might, possibly, be a bug, but it might not..."); - log("error", "We are going to send from %s instead of %s", tostring(host.from_host), tostring(from_host)); + log("error", "We are going to send from %s instead of %s", host.from_host, from_host); end if host.sends2s(stanza) then host.log("debug", "stanza sent over %s", host.type); @@ -151,7 +170,7 @@ module:hook("s2s-read-timeout", keepalive, -1); function module.add_host(module) if module:get_option_boolean("disallow_s2s", false) then - module:log("warn", "The 'disallow_s2s' config option is deprecated, please see http://prosody.im/doc/s2s#disabling"); + module:log("warn", "The 'disallow_s2s' config option is deprecated, please see https://prosody.im/doc/s2s#disabling"); return nil, "This host has disallow_s2s set"; end module:hook("route/remote", route_to_existing_session, -1); @@ -267,11 +286,21 @@ end --- XMPP stream event handlers -local stream_callbacks = { default_ns = "jabber:server", handlestanza = core_process_stanza }; +local stream_callbacks = { default_ns = "jabber:server" }; + +function stream_callbacks.handlestanza(session, stanza) + stanza = session.filter("stanzas/in", stanza); + session.thread:run(stanza); +end local xmlns_xmpp_streams = "urn:ietf:params:xml:ns:xmpp-streams"; function stream_callbacks.streamopened(session, attr) + -- run _streamopened in async context + session.thread:run({ attr = attr }); +end + +function stream_callbacks._streamopened(session, attr) session.version = tonumber(attr.version) or 0; -- TODO: Rename session.secure to session.encrypted @@ -366,7 +395,7 @@ function stream_callbacks.streamopened(session, attr) end if ( session.type == "s2sin" or session.type == "s2sout" ) or features.tags[1] then - log("debug", "Sending stream features: %s", tostring(features)); + log("debug", "Sending stream features: %s", features); session.sends2s(features); else (session.log or log)("warn", "No stream features to offer, giving up"); @@ -423,7 +452,7 @@ function stream_callbacks.error(session, error, data) session.log("debug", "Invalid opening stream header (%s)", (data:gsub("^([^\1]+)\1", "{%1}"))); session:close("invalid-namespace"); elseif error == "parse-error" then - session.log("debug", "Server-to-server XML parse error: %s", tostring(error)); + session.log("debug", "Server-to-server XML parse error: %s", error); session:close("not-well-formed"); elseif error == "stream-error" then local condition, text = "undefined-condition"; @@ -443,14 +472,6 @@ function stream_callbacks.error(session, error, data) end end -local function handleerr(err) log("error", "Traceback[s2s]: %s", traceback(tostring(err), 2)); end -function stream_callbacks.handlestanza(session, stanza) - stanza = session.filter("stanzas/in", stanza); - if stanza then - return xpcall(function () return core_process_stanza(session, stanza) end, handleerr); - end -end - local listener = {}; --- Session methods @@ -478,10 +499,10 @@ local function session_close(session, reason, remote_reason) if reason.extra then stanza:add_child(reason.extra); end - log("debug", "Disconnecting %s[%s], <stream:error> is: %s", session.host or session.ip or "(unknown host)", session.type, tostring(stanza)); + log("debug", "Disconnecting %s[%s], <stream:error> is: %s", session.host or session.ip or "(unknown host)", session.type, stanza); session.sends2s(stanza); elseif reason.name then -- a stanza - log("debug", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, tostring(reason)); + log("debug", "Disconnecting %s->%s[%s], <stream:error> is: %s", session.from_host or "(unknown host)", session.to_host or "(unknown host)", session.type, reason); session.sends2s(reason); end end @@ -525,6 +546,15 @@ end -- Session initialization logic shared by incoming and outgoing local function initialize_session(session) local stream = new_xmpp_stream(session, stream_callbacks); + + session.thread = runner(function (stanza) + if stanza.name == nil then + stream_callbacks._streamopened(session, stanza.attr); + else + core_process_stanza(session, stanza); + end + end, runner_callbacks, session); + local log = session.log or log; session.stream = stream; @@ -588,6 +618,20 @@ local function initialize_session(session) end); end +function runner_callbacks:ready() + self.data.log("debug", "Runner %s ready (%s)", self.thread, coroutine.status(self.thread)); + self.data.conn:resume(); +end + +function runner_callbacks:waiting() + self.data.log("debug", "Runner %s waiting (%s)", self.thread, coroutine.status(self.thread)); + self.data.conn:pause(); +end + +function runner_callbacks:error(err) + (self.data.log or log)("error", "Traceback[s2s]: %s", err); +end + function listener.onconnect(conn) conn:setoption("keepalive", opt_keepalives); local session = sessions[conn]; @@ -629,7 +673,7 @@ function listener.ondisconnect(conn, err) return; -- Session lives for now end end - (session.log or log)("debug", "s2s disconnected: %s->%s (%s)", tostring(session.from_host), tostring(session.to_host), tostring(err or "connection closed")); + (session.log or log)("debug", "s2s disconnected: %s->%s (%s)", session.from_host, session.to_host, err or "connection closed"); s2s_destroy_session(session, err); end end diff --git a/plugins/mod_s2s/s2sout.lib.lua b/plugins/mod_s2s/s2sout.lib.lua index 122ab6a9..1c0cd5ed 100644 --- a/plugins/mod_s2s/s2sout.lib.lua +++ b/plugins/mod_s2s/s2sout.lib.lua @@ -30,6 +30,7 @@ local sources = {}; local has_ipv4, has_ipv6; local dns_timeout = module:get_option_number("dns_timeout", 15); +local resolvers = module:get_option_set("s2s_dns_resolvers") local s2sout = {}; @@ -45,11 +46,18 @@ local function compare_srv_priorities(a,b) end function s2sout.initiate_connection(host_session) + local log = host_session.log or log; + initialize_filters(host_session); host_session.version = 1; host_session.resolver = adns.resolver(); host_session.resolver._resolver:settimeout(dns_timeout); + if resolvers then + for resolver in resolvers do + host_session.resolver._resolver:addnameserver(resolver); + end + end -- Kick the connection attempting machine into life if not s2sout.attempt_connection(host_session) then @@ -68,9 +76,9 @@ function s2sout.initiate_connection(host_session) buffer = {}; host_session.send_buffer = buffer; end - log("debug", "Buffering data on unconnected s2sout to %s", tostring(host_session.to_host)); + log("debug", "Buffering data on unconnected s2sout to %s", host_session.to_host); buffer[#buffer+1] = data; - log("debug", "Buffered item %d: %s", #buffer, tostring(data)); + log("debug", "Buffered item %d: %s", #buffer, data); end end end @@ -78,6 +86,7 @@ end function s2sout.attempt_connection(host_session, err) local to_host = host_session.to_host; local connect_host, connect_port = to_host and idna_to_ascii(to_host), 5269; + local log = host_session.log or log; if not connect_host then return false; @@ -129,16 +138,16 @@ function s2sout.attempt_connection(host_session, err) host_session.srv_choice = host_session.srv_choice + 1; local srv_choice = host_session.srv_hosts[host_session.srv_choice]; connect_host, connect_port = srv_choice.target or to_host, srv_choice.port or connect_port; - host_session.log("info", "Connection failed (%s). Attempt #%d: This time to %s:%d", tostring(err), host_session.srv_choice, connect_host, connect_port); + host_session.log("info", "Connection failed (%s). Attempt #%d: This time to %s:%d", err, host_session.srv_choice, connect_host, connect_port); else - host_session.log("info", "Failed in all attempts to connect to %s", tostring(host_session.to_host)); + host_session.log("info", "Failed in all attempts to connect to %s", host_session.to_host); -- We're out of options return false; end if not (connect_host and connect_port) then -- Likely we couldn't resolve DNS - log("warn", "Hmm, we're without a host (%s) and port (%s) to connect to for %s, giving up :(", tostring(connect_host), tostring(connect_port), tostring(to_host)); + log("warn", "Hmm, we're without a host (%s) and port (%s) to connect to for %s, giving up :(", connect_host, connect_port, to_host); return false; end @@ -160,6 +169,7 @@ end function s2sout.try_connect(host_session, connect_host, connect_port, err) host_session.connecting = true; + local log = host_session.log or log; if not err then local IPs = {}; @@ -246,6 +256,7 @@ function s2sout.try_connect(host_session, connect_host, connect_port, err) elseif host_session.ip_hosts and #host_session.ip_hosts > host_session.ip_choice then -- Not our first attempt, and we also have IPs left to try s2sout.try_next_ip(host_session); else + log("debug", "Out of IP addresses, trying next SRV record (if any)"); host_session.ip_hosts = nil; if not s2sout.attempt_connection(host_session, "out of IP addresses") then -- Retry if we can log("debug", "No other records to try for %s - destroying", host_session.to_host); @@ -259,7 +270,8 @@ function s2sout.try_connect(host_session, connect_host, connect_port, err) end function s2sout.make_connect(host_session, connect_host, connect_port) - (host_session.log or log)("debug", "Beginning new connection attempt to %s ([%s]:%d)", host_session.to_host, connect_host.addr, connect_port); + local log = host_session.log or log; + log("debug", "Beginning new connection attempt to %s ([%s]:%d)", host_session.to_host, connect_host.addr, connect_port); -- Reset secure flag in case this is another -- connection attempt after a failed STARTTLS diff --git a/plugins/mod_storage_internal.lua b/plugins/mod_storage_internal.lua index 76052575..27bd8830 100644 --- a/plugins/mod_storage_internal.lua +++ b/plugins/mod_storage_internal.lua @@ -44,17 +44,36 @@ local archive = {}; driver.archive = { __index = archive }; function archive:append(username, key, value, when, with) - key = key or id(); when = when or now(); if not st.is_stanza(value) then return nil, "unsupported-datatype"; end value = st.preserialize(st.clone(value)); - value.key = key; value.when = when; value.with = with; value.attr.stamp = datetime.datetime(when); value.attr.stamp_legacy = datetime.legacy(when); + + if key then + local items, err = datamanager.list_load(username, host, self.store); + if not items and err then return items, err; end + if items then + items = array(items); + items:filter(function (item) + return item.key ~= key; + end); + value.key = key; + items:push(value); + local ok, err = datamanager.list_store(username, host, self.store, items); + if not ok then return ok, err; end + return key; + end + else + key = id(); + end + + value.key = key; + local ok, err = datamanager.list_append(username, host, self.store, value); if not ok then return ok, err; end return key; @@ -141,9 +160,6 @@ function archive:delete(username, query) if not query or next(query) == nil then return datamanager.list_store(username, host, self.store, nil); end - for k in pairs(query) do - if k ~= "end" then return nil, "unsupported-query-field"; end - end local items, err = datamanager.list_load(username, host, self.store); if not items then if err then @@ -154,10 +170,48 @@ function archive:delete(username, query) end items = array(items); local count_before = #items; - items:filter(function (item) - return item.when > query["end"]; - end); + if query then + if query.key then + items:filter(function (item) + return item.key ~= query.key; + end); + end + if query.with then + items:filter(function (item) + return item.with ~= query.with; + end); + end + if query.start then + items:filter(function (item) + return item.when < query.start; + end); + end + if query["end"] then + items:filter(function (item) + return item.when > query["end"]; + end); + end + if query.truncate then + if query.reverse then + -- Before: { 1, 2, 3, 4, 5, } + -- After: { 1, 2, 3 } + for i = #items, query.truncate + 1, -1 do + items[i] = nil; + end + else + -- Before: { 1, 2, 3, 4, 5, } + -- After: { 3, 4, 5 } + local offset = #items - query.truncate; + for i = 1, #items do + items[i] = items[i+offset]; + end + end + end + end local count = count_before - #items; + if count == 0 then + return 0; -- No changes, skip write + end local ok, err = datamanager.list_store(username, host, self.store, items); if not ok then return ok, err; end return count; diff --git a/plugins/mod_storage_sql.lua b/plugins/mod_storage_sql.lua index 13c961f8..74a9665b 100644 --- a/plugins/mod_storage_sql.lua +++ b/plugins/mod_storage_sql.lua @@ -43,12 +43,17 @@ local function deserialize(t, value) elseif t == "boolean" then if value == "true" then return true; elseif value == "false" then return false; end - elseif t == "number" then return tonumber(value); + return nil, "invalid-boolean"; + elseif t == "number" then + value = tonumber(value); + if value then return value; end + return nil, "invalid-number"; elseif t == "json" then return json.decode(value); elseif t == "xml" then return xml_parse(value); end + return nil, "Unhandled value type: "..t; end local host = module.host; @@ -65,7 +70,8 @@ local function keyval_store_get() for row in engine:select(select_sql, host, user or "", store) do haveany = true; local k = row[1]; - local v = deserialize(row[2], row[3]); + local v, e = deserialize(row[2], row[3]); + assert(v ~= nil, e); if k and v then if k ~= "" then result[k] = v; elseif type(v) == "table" then for a,b in pairs(v) do @@ -154,15 +160,17 @@ function map_store:get(username, key) WHERE "host"=? AND "user"=? AND "store"=? AND "key"=? LIMIT 1 ]]; - local data; + local data, err; if type(key) == "string" and key ~= "" then for row in engine:select(query, host, username or "", self.store, key) do - data = deserialize(row[1], row[2]); + data, err = deserialize(row[1], row[2]); + assert(data ~= nil, err); end return data; else for row in engine:select(query, host, username or "", self.store, "") do - data = deserialize(row[1], row[2]); + data, err = deserialize(row[1], row[2]); + assert(data ~= nil, err); end return data and data[key] or nil; end @@ -200,9 +208,10 @@ function map_store:set_keys(username, keydatas) engine:insert(insert_sql, host, username or "", self.store, key, t, value); end else - local extradata = {}; + local extradata, err = {}; for row in engine:select(select_extradata_sql, host, username or "", self.store, "") do - extradata = deserialize(row[1], row[2]); + extradata, err = deserialize(row[1], row[2]); + assert(extradata ~= nil, err); end engine:delete(delete_sql, host, username or "", self.store, ""); extradata[key] = data; @@ -356,7 +365,9 @@ function archive_store:find(username, query) return function() local row = result(); if row ~= nil then - return row[1], deserialize(row[2], row[3]), row[4], row[5]; + local value, err = deserialize(row[2], row[3]); + assert(value ~= nil, err); + return row[1], value, row[4], row[5]; end end, total; end @@ -374,7 +385,35 @@ function archive_store:delete(username, query) end archive_where(query, args, where); archive_where_id_range(query, args, where); - sql_query = sql_query:format(t_concat(where, " AND ")); + if query.truncate == nil then + sql_query = sql_query:format(t_concat(where, " AND ")); + else + args[#args+1] = query.truncate; + local unlimited = "ALL"; + if engine.params.driver == "SQLite3" then + sql_query = [[ + DELETE FROM "prosodyarchive" + WHERE %s + ORDER BY "sort_id" %s + LIMIT %s OFFSET ?; + ]]; + unlimited = "-1"; + else + sql_query = [[ + DELETE FROM "prosodyarchive" + WHERE "sort_id" IN ( + SELECT "sort_id" FROM "prosodyarchive" + WHERE %s + ORDER BY "sort_id" %s + LIMIT %s OFFSET ? + );]]; + if engine.params.driver == "MySQL" then + unlimited = "18446744073709551615"; + end + end + sql_query = string.format(sql_query, t_concat(where, " AND "), + query.reverse and "ASC" or "DESC", unlimited); + end return engine:delete(sql_query, unpack(args)); end); return ok and stmt:affected(), stmt; @@ -427,7 +466,7 @@ local function create_table(engine, name) -- luacheck: ignore 431/engine local Table, Column, Index = sql.Table, sql.Column, sql.Index; local ProsodyTable = Table { - name= name or "prosody"; + name = "prosody"; Column { name="host", type="TEXT", nullable=false }; Column { name="user", type="TEXT", nullable=false }; Column { name="store", type="TEXT", nullable=false }; @@ -477,7 +516,7 @@ local function upgrade_table(engine, params, apply_changes) -- luacheck: ignore end); if not success then module:log("error", "Failed to check/upgrade database schema (%s), please see " - .."http://prosody.im/doc/mysql for help", + .."https://prosody.im/doc/mysql for help", err or "unknown error"); return false; end diff --git a/plugins/mod_storage_sql1.lua b/plugins/mod_storage_sql1.lua index a5bb5bfa..e1041bca 100644 --- a/plugins/mod_storage_sql1.lua +++ b/plugins/mod_storage_sql1.lua @@ -130,7 +130,7 @@ local function create_table() module:log("info", "Database table automatically upgraded"); else module:log("error", "Failed to upgrade database schema (%s), please see " - .."http://prosody.im/doc/mysql for help", + .."https://prosody.im/doc/mysql for help", err or "unknown error"); end end @@ -139,7 +139,7 @@ local function create_table() end elseif params.driver ~= "SQLite3" then -- SQLite normally fails to prepare for existing table module:log("warn", "Prosody was not able to automatically check/create the database table (%s), " - .."see http://prosody.im/doc/modules/mod_storage_sql#table_management for help.", + .."see https://prosody.im/doc/modules/mod_storage_sql#table_management for help.", err or "unknown error"); end end @@ -151,7 +151,7 @@ do -- process options to get a db connection if not ok then package.loaded["DBI"] = {}; module:log("error", "Failed to load the LuaDBI library for accessing SQL databases: %s", DBI); - module:log("error", "More information on installing LuaDBI can be found at http://prosody.im/doc/depends#luadbi"); + module:log("error", "More information on installing LuaDBI can be found at https://prosody.im/doc/depends#luadbi"); end prosody.lock_globals(); if not ok or not DBI.Connect then diff --git a/plugins/mod_storage_xep0227.lua b/plugins/mod_storage_xep0227.lua index ef227ca3..229ad6b5 100644 --- a/plugins/mod_storage_xep0227.lua +++ b/plugins/mod_storage_xep0227.lua @@ -164,10 +164,84 @@ handlers.private = { end; }; +handlers.roster = { + get = function(self, user) + user = getUserElement(getXml(user, self.host)); + if user then + local roster = user:get_child("query", "jabber:iq:roster"); + if roster then + local r = { + [false] = { + version = roster.attr.version; + pending = {}; + } + }; + for item in roster:childtags("item") do + r[item.attr.jid] = { + jid = item.attr.jid, + subscription = item.attr.subscription, + ask = item.attr.ask, + name = item.attr.name, + groups = {}; + }; + for group in item:childtags("group") do + r[item.attr.jid].groups[group:get_text()] = true; + end + for pending in user:childtags("presence", "jabber:client") do + r[false].pending[pending.attr.from] = true; + end + end + return r; + end + end + end; + set = function(self, user, data) + local xml = getXml(user, self.host); + local usere = xml and getUserElement(xml); + if usere then + local roster = usere:get_child("query", 'jabber:iq:roster'); + if roster then removeStanzaChild(usere, roster); end + usere:maptags(function (tag) + if tag.attr.xmlns == "jabber:client" and tag.name == "presence" and tag.attr.type == "subscribe" then + return nil; + end + return tag; + end); + if data and next(data) ~= nil then + roster = st.stanza("query", {xmlns='jabber:iq:roster'}); + usere:add_child(roster); + for jid, item in pairs(data) do + if jid then + roster:tag("item", { + jid = jid, + subscription = item.subscription, + ask = item.ask, + name = item.name, + }); + for group in pairs(item.groups) do + roster:tag("group"):text(group):up(); + end + roster:up(); -- move out from item + else + roster.attr.version = item.version; + for pending_jid in pairs(item.pending) do + usere:add_child(st.presence({ from = pending_jid, type = "subscribe" })); + end + end + end + end + return setXml(user, self.host, xml); + end + return true; + end; +}; + + ----------------------------- local driver = {}; -function driver:open(datastore, typ) +function driver:open(datastore, typ) -- luacheck: ignore 212/self + if typ and typ ~= "keyval" then return nil, "unsupported-store"; end local handler = handlers[datastore]; if not handler then return nil, "unsupported-datastore"; end local instance = setmetatable({ host = module.host; datastore = datastore; }, { __index = handler }); diff --git a/plugins/mod_uptime.lua b/plugins/mod_uptime.lua index 2e369b16..a0a844a1 100644 --- a/plugins/mod_uptime.lua +++ b/plugins/mod_uptime.lua @@ -39,7 +39,7 @@ function uptime_text() minutes, (minutes ~= 1 and "s") or "", os.date("%c", prosody.start_time)); end -function uptime_command_handler (self, data, state) +function uptime_command_handler () return { info = uptime_text(), status = "completed" }; end diff --git a/plugins/mod_user_account_management.lua b/plugins/mod_user_account_management.lua new file mode 100644 index 00000000..615c1ed6 --- /dev/null +++ b/plugins/mod_user_account_management.lua @@ -0,0 +1,86 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + + +local st = require "util.stanza"; +local usermanager_set_password = require "core.usermanager".set_password; +local usermanager_delete_user = require "core.usermanager".delete_user; +local nodeprep = require "util.encodings".stringprep.nodeprep; +local jid_bare = require "util.jid".bare; + +local compat = module:get_option_boolean("registration_compat", true); + +module:add_feature("jabber:iq:register"); + +-- Password change and account deletion handler +local function handle_registration_stanza(event) + local session, stanza = event.origin, event.stanza; + local log = session.log or module._log; + + local query = stanza.tags[1]; + if stanza.attr.type == "get" then + local reply = st.reply(stanza); + reply:tag("query", {xmlns = "jabber:iq:register"}) + :tag("registered"):up() + :tag("username"):text(session.username):up() + :tag("password"):up(); + session.send(reply); + else -- stanza.attr.type == "set" + if query.tags[1] and query.tags[1].name == "remove" then + local username, host = session.username, session.host; + + -- This one weird trick sends a reply to this stanza before the user is deleted + local old_session_close = session.close; + session.close = function(self, ...) + self.send(st.reply(stanza)); + return old_session_close(self, ...); + end + + local ok, err = usermanager_delete_user(username, host); + + if not ok then + log("debug", "Removing user account %s@%s failed: %s", username, host, err); + session.close = old_session_close; + session.send(st.error_reply(stanza, "cancel", "service-unavailable", err)); + return true; + end + + log("info", "User removed their account: %s@%s", username, host); + module:fire_event("user-deregistered", { username = username, host = host, source = "mod_register", session = session }); + else + local username = nodeprep(query:get_child_text("username")); + local password = query:get_child_text("password"); + if username and password then + if username == session.username then + if usermanager_set_password(username, password, session.host, session.resource) then + session.send(st.reply(stanza)); + else + -- TODO unable to write file, file may be locked, etc, what's the correct error? + session.send(st.error_reply(stanza, "wait", "internal-server-error")); + end + else + session.send(st.error_reply(stanza, "modify", "bad-request")); + end + else + session.send(st.error_reply(stanza, "modify", "bad-request")); + end + end + end + return true; +end + +module:hook("iq/self/jabber:iq:register:query", handle_registration_stanza); +if compat then + module:hook("iq/host/jabber:iq:register:query", function (event) + local session, stanza = event.origin, event.stanza; + if session.type == "c2s" and jid_bare(stanza.attr.to) == session.host then + return handle_registration_stanza(event); + end + end); +end + diff --git a/plugins/mod_watchregistrations.lua b/plugins/mod_watchregistrations.lua index 82666b09..825b8a73 100644 --- a/plugins/mod_watchregistrations.lua +++ b/plugins/mod_watchregistrations.lua @@ -13,12 +13,13 @@ local jid_prep = require "util.jid".prep; local registration_watchers = module:get_option_set("registration_watchers", module:get_option("admins", {})) / jid_prep; local registration_from = module:get_option_string("registration_from", host); local registration_notification = module:get_option_string("registration_notification", "User $username just registered on $host from $ip"); +local msg_type = module:get_option_string("registration_notification_type", "chat"); local st = require "util.stanza"; module:hook("user-registered", function (user) module:log("debug", "Notifying of new registration"); - local message = st.message{ type = "chat", from = registration_from } + local message = st.message{ type = msg_type, from = registration_from } :tag("body") :text(registration_notification:gsub("%$(%w+)", function (v) return user[v] or user.session and user.session[v] or nil; diff --git a/plugins/mod_websocket.lua b/plugins/mod_websocket.lua index edc104df..d301088e 100644 --- a/plugins/mod_websocket.lua +++ b/plugins/mod_websocket.lua @@ -256,6 +256,10 @@ function handle_request(event) local session = sessions[conn]; + -- Use upstream IP if a HTTP proxy was used + -- See mod_http and #540 + session.ip = request.ip; + session.secure = consider_websocket_secure or session.secure; session.websocket_request = request; diff --git a/plugins/muc/description.lib.lua b/plugins/muc/description.lib.lua new file mode 100644 index 00000000..d7e3f7c6 --- /dev/null +++ b/plugins/muc/description.lib.lua @@ -0,0 +1,51 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local function get_description(room) + return room._data.description; +end + +local function set_description(room, description) + if description == "" then description = nil; end + if get_description(room) == description then return false; end + room._data.description = description; + return true; +end + +local function add_disco_form(event) + table.insert(event.form, { + name = "muc#roominfo_description"; + label = "Description"; + value = ""; + }); + event.formdata["muc#roominfo_description"] = get_description(event.room); +end + +local function add_form_option(event) + table.insert(event.form, { + name = "muc#roomconfig_roomdesc"; + type = "text-single"; + label = "Description"; + value = get_description(event.room) or ""; + }); +end + +module:hook("muc-disco#info", add_disco_form); +module:hook("muc-config-form", add_form_option, 100-2); + +module:hook("muc-config-submitted/muc#roomconfig_roomdesc", function(event) + if set_description(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +return { + get = get_description; + set = set_description; +}; diff --git a/plugins/muc/hidden.lib.lua b/plugins/muc/hidden.lib.lua new file mode 100644 index 00000000..b2fe6216 --- /dev/null +++ b/plugins/muc/hidden.lib.lua @@ -0,0 +1,43 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local function get_hidden(room) + return room._data.hidden; +end + +local function set_hidden(room, hidden) + hidden = hidden and true or nil; + if get_hidden(room) == hidden then return false; end + room._data.hidden = hidden; + return true; +end + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_publicroom"; + type = "boolean"; + label = "Make Room Publicly Searchable?"; + value = not get_hidden(event.room); + }); +end, 100-5); + +module:hook("muc-config-submitted/muc#roomconfig_publicroom", function(event) + if set_hidden(event.room, not event.value) then + event.status_codes["104"] = true; + end +end); + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = get_hidden(event.room) and "muc_hidden" or "muc_public"}):up(); +end); + +return { + get = get_hidden; + set = set_hidden; +}; diff --git a/plugins/muc/history.lib.lua b/plugins/muc/history.lib.lua new file mode 100644 index 00000000..100ab720 --- /dev/null +++ b/plugins/muc/history.lib.lua @@ -0,0 +1,204 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local gettime = os.time; +local datetime = require "util.datetime"; +local st = require "util.stanza"; + +local default_history_length = 20; +local max_history_length = module:get_option_number("max_history_messages", math.huge); + +local function set_max_history_length(_max_history_length) + max_history_length = _max_history_length or math.huge; +end + +local function get_historylength(room) + return math.min(room._data.history_length or default_history_length, max_history_length); +end + +local function set_historylength(room, length) + if length then + length = assert(tonumber(length), "Length not a valid number"); + end + if length == default_history_length then length = nil; end + room._data.history_length = length; + return true; +end + +-- Fix for clients who don't support XEP-0045 correctly +-- Default number of history messages the room returns +local function get_defaulthistorymessages(room) + return room._data.default_history_messages or default_history_length; +end +local function set_defaulthistorymessages(room, number) + number = math.min(tonumber(number) or default_history_length, room._data.history_length or default_history_length); + if number == default_history_length then + number = nil; + end + room._data.default_history_messages = number; +end + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_historylength"; + type = "text-single"; + label = "Maximum Number of History Messages Returned by Room"; + value = tostring(get_historylength(event.room)); + }); + table.insert(event.form, { + name = 'muc#roomconfig_defaulthistorymessages', + type = 'text-single', + label = 'Default Number of History Messages Returned by Room', + value = tostring(get_defaulthistorymessages(event.room)) + }); +end, 100-10); + +module:hook("muc-config-submitted/muc#roomconfig_historylength", function(event) + if set_historylength(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +module:hook("muc-config-submitted/muc#roomconfig_defaulthistorymessages", function(event) + if set_defaulthistorymessages(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +local function parse_history(stanza) + local x_tag = stanza:get_child("x", "http://jabber.org/protocol/muc"); + local history_tag = x_tag and x_tag:get_child("history", "http://jabber.org/protocol/muc"); + if not history_tag then + return nil, nil, nil; + end + + local maxchars = tonumber(history_tag.attr.maxchars); + + local maxstanzas = tonumber(history_tag.attr.maxstanzas); + + -- messages received since the UTC datetime specified + local since = history_tag.attr.since; + if since then + since = datetime.parse(since); + end + + -- messages received in the last "X" seconds. + local seconds = tonumber(history_tag.attr.seconds); + if seconds then + seconds = gettime() - seconds; + if since then + since = math.max(since, seconds); + else + since = seconds; + end + end + + return maxchars, maxstanzas, since; +end + +module:hook("muc-get-history", function(event) + local room = event.room; + local history = room._history; -- send discussion history + if not history then return nil end + local history_len = #history; + + local to = event.to; + local maxchars = event.maxchars; + local maxstanzas = event.maxstanzas or history_len; + local since = event.since; + local n = 0; + local charcount = 0; + for i=history_len,1,-1 do + local entry = history[i]; + if maxchars then + if not entry.chars then + entry.stanza.attr.to = ""; + entry.chars = #tostring(entry.stanza); + end + charcount = charcount + entry.chars + #to; + if charcount > maxchars then break; end + end + if since and since > entry.timestamp then break; end + if n + 1 > maxstanzas then break; end + n = n + 1; + end + + local i = history_len-n+1 + function event.next_stanza() + if i > history_len then return nil end + local entry = history[i]; + local msg = entry.stanza; + msg.attr.to = to; + i = i + 1; + return msg; + end + return true; +end, -1); + +local function send_history(room, stanza) + local maxchars, maxstanzas, since = parse_history(stanza); + if not(maxchars or maxstanzas or since) then + maxstanzas = get_defaulthistorymessages(room); + end + local event = { + room = room; + stanza = stanza; + to = stanza.attr.from; -- `to` is required to calculate the character count for `maxchars` + maxchars = maxchars, + maxstanzas = maxstanzas, + since = since; + next_stanza = function() end; -- events should define this iterator + }; + module:fire_event("muc-get-history", event); + for msg in event.next_stanza, event do + room:route_stanza(msg); + end +end + +-- Send history on join +module:hook("muc-occupant-session-new", function(event) + send_history(event.room, event.stanza); +end, 50); -- Before subject(20) + +-- add to history +module:hook("muc-add-history", function(event) + local room = event.room + local history = room._history; + if not history then history = {}; room._history = history; end + local stanza = st.clone(event.stanza); + stanza.attr.to = ""; + local ts = gettime(); + local stamp = datetime.datetime(ts); + stanza:tag("delay", {xmlns = "urn:xmpp:delay", from = module.host, stamp = stamp}):up(); -- XEP-0203 + stanza:tag("x", {xmlns = "jabber:x:delay", from = module.host, stamp = datetime.legacy()}):up(); -- XEP-0091 (deprecated) + local entry = { stanza = stanza, timestamp = ts }; + table.insert(history, entry); + while #history > get_historylength(room) do table.remove(history, 1) end + return true; +end, -1); + +-- Have a single muc-add-history event, so that plugins can mark it +-- as handled without stopping other muc-broadcast-message handlers +module:hook("muc-broadcast-message", function(event) + if module:fire_event("muc-message-is-historic", event) then + module:fire_event("muc-add-history", event); + end +end); + +module:hook("muc-message-is-historic", function (event) + return event.stanza:get_child("body"); +end, -1); + +return { + set_max_length = set_max_history_length; + parse_history = parse_history; + send = send_history; + get_length = get_historylength; + set_length = set_historylength; +}; diff --git a/plugins/muc/language.lib.lua b/plugins/muc/language.lib.lua new file mode 100644 index 00000000..ae9bcfed --- /dev/null +++ b/plugins/muc/language.lib.lua @@ -0,0 +1,50 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local function get_language(room) + return room._data.language; +end + +local function set_language(room, language) + if language == "" then language = nil; end + if get_language(room) == language then return false; end + room._data.language = language; + return true; +end + +local function add_disco_form(event) + table.insert(event.form, { + name = "muc#roominfo_lang"; + value = ""; + }); + event.formdata["muc#roominfo_lang"] = get_language(event.room); +end + +local function add_form_option(event) + table.insert(event.form, { + name = "muc#roomconfig_lang"; + label = "Language tag for Room (e.g. 'en', 'de', 'fr' etc.)"; + type = "text-single"; + value = get_language(event.room) or ""; + }); +end + +module:hook("muc-disco#info", add_disco_form); +module:hook("muc-config-form", add_form_option, 100-9.5); + +module:hook("muc-config-submitted/muc#roomconfig_lang", function(event) + if set_language(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +return { + get = get_language; + set = set_language; +}; diff --git a/plugins/muc/lock.lib.lua b/plugins/muc/lock.lib.lua new file mode 100644 index 00000000..062ab615 --- /dev/null +++ b/plugins/muc/lock.lib.lua @@ -0,0 +1,62 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local st = require "util.stanza"; + +local lock_rooms = module:get_option_boolean("muc_room_locking", true); +local lock_room_timeout = module:get_option_number("muc_room_lock_timeout", 300); + +local function lock(room) + module:fire_event("muc-room-locked", {room = room;}); + room._data.locked = os.time() + lock_room_timeout; +end +local function unlock(room) + module:fire_event("muc-room-unlocked", {room = room;}); + room._data.locked = nil; +end +local function is_locked(room) + local ts = room._data.locked; + if ts then + if os.time() < ts then return true; end + unlock(room); + end + return false; +end + +if lock_rooms then + module:hook("muc-room-pre-create", function(event) + -- Older groupchat protocol doesn't lock + if not event.stanza:get_child("x", "http://jabber.org/protocol/muc") then return end + -- Lock room at creation + local room = event.room; + lock(room); + end, 10); +end + +-- Don't let users into room while it is locked +module:hook("muc-occupant-pre-join", function(event) + if not event.is_new_room and is_locked(event.room) then -- Deny entry + module:log("debug", "Room is locked, denying entry"); + event.origin.send(st.error_reply(event.stanza, "cancel", "item-not-found")); + return true; + end +end, -30); + +-- When config is submitted; unlock the room +module:hook("muc-config-submitted", function(event) + if is_locked(event.room) then + unlock(event.room); + end +end, -1); + +return { + lock = lock; + unlock = unlock; + is_locked = is_locked; +}; diff --git a/plugins/muc/members_only.lib.lua b/plugins/muc/members_only.lib.lua new file mode 100644 index 00000000..1e5e6a56 --- /dev/null +++ b/plugins/muc/members_only.lib.lua @@ -0,0 +1,128 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local st = require "util.stanza"; + +local muc_util = module:require "muc/util"; +local valid_affiliations = muc_util.valid_affiliations; + +local function get_members_only(room) + return room._data.members_only; +end + +local function set_members_only(room, members_only) + members_only = members_only and true or nil; + if room._data.members_only == members_only then return false; end + room._data.members_only = members_only; + if members_only then + --[[ + If as a result of a change in the room configuration the room type is + changed to members-only but there are non-members in the room, + the service MUST remove any non-members from the room and include a + status code of 322 in the presence unavailable stanzas sent to those users + as well as any remaining occupants. + ]] + local occupants_changed = {}; + for _, occupant in room:each_occupant() do + local affiliation = room:get_affiliation(occupant.bare_jid); + if valid_affiliations[affiliation or "none"] <= valid_affiliations.none then + occupant.role = nil; + room:save_occupant(occupant); + occupants_changed[occupant] = true; + end + end + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}) + :tag("status", {code="322"}):up(); + for occupant in pairs(occupants_changed) do + room:publicise_occupant_status(occupant, x); + module:fire_event("muc-occupant-left", {room = room; nick = occupant.nick; occupant = occupant;}); + end + end + return true; +end + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = get_members_only(event.room) and "muc_membersonly" or "muc_open"}):up(); +end); + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_membersonly"; + type = "boolean"; + label = "Make Room Members-Only?"; + value = get_members_only(event.room); + }); +end, 100-6); + +module:hook("muc-config-submitted/muc#roomconfig_membersonly", function(event) + if set_members_only(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +-- No affiliation => role of "none" +module:hook("muc-get-default-role", function(event) + if not event.affiliation and get_members_only(event.room) then + return false; + end +end); + +-- registration required for entering members-only room +module:hook("muc-occupant-pre-join", function(event) + local room = event.room; + if get_members_only(room) then + local stanza = event.stanza; + local affiliation = room:get_affiliation(stanza.attr.from); + if valid_affiliations[affiliation or "none"] <= valid_affiliations.none then + local reply = st.error_reply(stanza, "auth", "registration-required"):up(); + reply.tags[1].attr.code = "407"; + event.origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); + return true; + end + end +end, -5); + +-- Invitation privileges in members-only rooms SHOULD be restricted to room admins; +-- if a member without privileges to edit the member list attempts to invite another user +-- the service SHOULD return a <forbidden/> error to the occupant +module:hook("muc-pre-invite", function(event) + local room = event.room; + if get_members_only(room) then + local stanza = event.stanza; + local affiliation = room:get_affiliation(stanza.attr.from); + if valid_affiliations[affiliation or "none"] < valid_affiliations.admin then + event.origin.send(st.error_reply(stanza, "auth", "forbidden")); + return true; + end + end +end); + +-- When an invite is sent; add an affiliation for the invitee +module:hook("muc-invite", function(event) + local room = event.room; + if get_members_only(room) then + local stanza = event.stanza; + local invitee = stanza.attr.to; + local affiliation = room:get_affiliation(invitee); + if valid_affiliations[affiliation or "none"] <= valid_affiliations.none then + local from = stanza:get_child("x", "http://jabber.org/protocol/muc#user") + :get_child("invite").attr.from; + module:log("debug", "%s invited %s into members only room %s, granting membership", + from, invitee, room.jid); + -- This might fail; ignore for now + room:set_affiliation(from, invitee, "member", "Invited by " .. from); + room:save(); + end + end +end); + +return { + get = get_members_only; + set = set_members_only; +}; diff --git a/plugins/muc/mod_muc.lua b/plugins/muc/mod_muc.lua index 0f58bfbc..c3975282 100644 --- a/plugins/muc/mod_muc.lua +++ b/plugins/muc/mod_muc.lua @@ -6,288 +6,416 @@ -- COPYING file in the source package for more information. -- -local array = require "util.array"; +-- Exposed functions: +-- +-- create_room(jid) -> room +-- track_room(room) +-- delete_room(room) +-- forget_room(room) +-- get_room_from_jid(jid) -> room +-- each_room(local_only) -> () -> room +-- shutdown_component() if module:get_host_type() ~= "component" then - error("MUC should be loaded as a component, please see http://prosody.im/doc/components", 0); + error("MUC should be loaded as a component, please see https://prosody.im/doc/components", 0); end -local muc_host = module:get_host(); -local muc_name = module:get_option_string("name", "Prosody Chatrooms"); -local restrict_room_creation = module:get_option("restrict_room_creation"); -if restrict_room_creation then - if restrict_room_creation == true then - restrict_room_creation = "admin"; - elseif restrict_room_creation ~= "admin" and restrict_room_creation ~= "local" then - restrict_room_creation = nil; - end +local muclib = module:require "muc"; +room_mt = muclib.room_mt; -- Yes, global. +new_room = muclib.new_room; + +local name = module:require "muc/name"; +room_mt.get_name = name.get; +room_mt.set_name = name.set; + +local description = module:require "muc/description"; +room_mt.get_description = description.get; +room_mt.set_description = description.set; + +local language = module:require "muc/language"; +room_mt.get_language = language.get; +room_mt.set_language = language.set; + +local hidden = module:require "muc/hidden"; +room_mt.get_hidden = hidden.get; +room_mt.set_hidden = hidden.set; +function room_mt:get_public() + return not self:get_hidden(); +end +function room_mt:set_public(public) + return self:set_hidden(not public); end -local lock_rooms = module:get_option_boolean("muc_room_locking", false); -local lock_room_timeout = module:get_option_number("muc_room_lock_timeout", 300); -local muclib = module:require "muc"; -local muc_new_room = muclib.new_room; +local password = module:require "muc/password"; +room_mt.get_password = password.get; +room_mt.set_password = password.set; + +local members_only = module:require "muc/members_only"; +room_mt.get_members_only = members_only.get; +room_mt.set_members_only = members_only.set; + +local moderated = module:require "muc/moderated"; +room_mt.get_moderated = moderated.get; +room_mt.set_moderated = moderated.set; + +local request = module:require "muc/request"; +room_mt.handle_role_request = request.handle_request; + +local persistent = module:require "muc/persistent"; +room_mt.get_persistent = persistent.get; +room_mt.set_persistent = persistent.set; + +local subject = module:require "muc/subject"; +room_mt.get_changesubject = subject.get_changesubject; +room_mt.set_changesubject = subject.set_changesubject; +room_mt.get_subject = subject.get; +room_mt.set_subject = subject.set; +room_mt.send_subject = subject.send; + +local history = module:require "muc/history"; +room_mt.send_history = history.send; +room_mt.get_historylength = history.get_length; +room_mt.set_historylength = history.set_length; + local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; local st = require "util.stanza"; -local uuid_gen = require "util.uuid".generate; +local cache = require "util.cache"; local um_is_admin = require "core.usermanager".is_admin; -local hosts = prosody.hosts; - -rooms = {}; -local rooms = rooms; -local persistent_rooms_storage = module:open_store("persistent"); -local persistent_rooms, err = persistent_rooms_storage:get(); -if not persistent_rooms then - if err then - module:log("error", "Error loading list of persistent rooms from storage. Reload mod_muc or restart to recover."); - error("Storage error: "..err); - end - module:log("debug", "No persistent rooms found in the database"); - persistent_rooms = {}; -end -local room_configs = module:open_store("config"); - --- Configurable options -muclib.set_max_history_length(module:get_option_number("max_history_messages")); module:depends("disco"); -module:add_identity("conference", "text", muc_name); +module:add_identity("conference", "text", module:get_option_string("name", "Prosody Chatrooms")); module:add_feature("http://jabber.org/protocol/muc"); +module:depends "muc_unique" +module:require "muc/lock"; local function is_admin(jid) return um_is_admin(jid, module.host); end -room_mt = muclib.room_mt; -- Yes, global. -local _set_affiliation = room_mt.set_affiliation; -local _get_affiliation = room_mt.get_affiliation; -function muclib.room_mt:get_affiliation(jid) - if is_admin(jid) then return "owner"; end - return _get_affiliation(self, jid); -end -function muclib.room_mt:set_affiliation(actor, jid, affiliation, callback, reason) - if affiliation ~= "owner" and is_admin(jid) then return nil, "modify", "not-acceptable"; end - return _set_affiliation(self, actor, jid, affiliation, callback, reason); +do -- Monkey patch to make server admins room owners + local _get_affiliation = room_mt.get_affiliation; + function room_mt:get_affiliation(jid) + if is_admin(jid) then return "owner"; end + return _get_affiliation(self, jid); + end + + local _set_affiliation = room_mt.set_affiliation; + function room_mt:set_affiliation(actor, jid, affiliation, reason) + if affiliation ~= "owner" and is_admin(jid) then return nil, "modify", "not-acceptable"; end + return _set_affiliation(self, actor, jid, affiliation, reason); + end end -local function room_route_stanza(room, stanza) module:send(stanza); end -local function room_save(room, forced) +local persistent_rooms_storage = module:open_store("persistent"); +local persistent_rooms = module:open_store("persistent", "map"); +local room_configs = module:open_store("config"); +local room_state = module:open_store("state"); + +local room_items_cache = {}; + +local function room_save(room, forced, savestate) local node = jid_split(room.jid); - persistent_rooms[room.jid] = room._data.persistent; - if room._data.persistent then - local history = room._data.history; - room._data.history = nil; - local data = { - jid = room.jid; - _data = room._data; - _affiliations = room._affiliations; - }; - room_configs:set(node, data); - room._data.history = history; + local is_persistent = persistent.get(room); + room_items_cache[room.jid] = room:get_public() and room:get_name() or nil; + if is_persistent or savestate then + persistent_rooms:set(nil, room.jid, true); + local data, state = room:freeze(savestate); + room_state:set(node, state); + return room_configs:set(node, data); elseif forced then - room_configs:set(node, nil); - if not next(room._occupants) then -- Room empty - rooms[room.jid] = nil; - end + persistent_rooms:set(nil, room.jid, nil); + room_state:set(node, nil); + return room_configs:set(node, nil); end - if forced then persistent_rooms_storage:set(nil, persistent_rooms); end end -function create_room(jid, locked) - local room = muc_new_room(jid); - room.route_stanza = room_route_stanza; - room.save = room_save; - rooms[jid] = room; - if locked then - room.locked = true; - if lock_room_timeout and lock_room_timeout > 0 then - module:add_timer(lock_room_timeout, function () - if room.locked then - room:destroy(); -- Not unlocked in time - end - end); - end +local max_rooms = module:get_option_number("muc_max_rooms"); +local max_live_rooms = module:get_option_number("muc_room_cache_size", 100); + +local room_hit = module:measure("room_hit", "rate"); +local room_miss = module:measure("room_miss", "rate") +local room_eviction = module:measure("room_eviction", "rate"); +local rooms = cache.new(max_rooms or max_live_rooms, function (jid, room) + if max_rooms then + module:log("info", "Room limit of %d reached, no new rooms allowed"); + return false; + end + module:log("debug", "Evicting room %s", jid); + room_eviction(); + room_items_cache[room.jid] = room:get_public() and room:get_name() or nil; + local ok, err = room_save(room, nil, true); -- Force to disk + if not ok then + module:log("error", "Failed to swap inactive room %s to disk: %s", jid, err); + return false; + end +end); + +-- Automatically destroy empty non-persistent rooms +module:hook("muc-occupant-left",function(event) + local room = event.room + if not room:has_occupant() and not persistent.get(room) then -- empty, non-persistent room + module:fire_event("muc-room-destroyed", { room = room }); end - module:fire_event("muc-room-created", { room = room }); - return room; +end, -1); + +function track_room(room) + if rooms:set(room.jid, room) then + -- When room is created, over-ride 'save' method + room.save = room_save; + return room; + end + return false; end -local persistent_errors = false; -for jid in pairs(persistent_rooms) do +local function handle_broken_room(room, origin, stanza) + module:log("debug", "Returning error from broken room %s", room.jid); + origin.send(st.error_reply(stanza, "wait", "internal-server-error")); + return true; +end + +local function restore_room(jid) local node = jid_split(jid); local data, err = room_configs:get(node); if data then - local room = create_room(jid); - room._data = data._data; - room._affiliations = data._affiliations; - elseif not err then -- missing room data - persistent_rooms[jid] = nil; - module:log("error", "Missing data for room '%s', removing from persistent room list", jid); - persistent_errors = true; - else -- error - module:log("error", "Error loading data for room '%s', locking it until service restart. Error was: %s", jid, err); - local room = muc_new_room(jid); - room.locked = true; - room._affiliations = { [muc_host] = "owner" }; -- To prevent unlocking - rooms[jid] = room; + module:log("debug", "Restoring room %s from storage", jid); + local state, s_err = room_state:get(node); + if not state and s_err then + module:log("debug", "Could not restore state of room %s: %s", jid, s_err); + end + local room = muclib.restore_room(data, state); + return track_room(room); + elseif err then + module:log("error", "Error restoring room %s from storage: %s", jid, err); + local room = muclib.new_room(jid, { locked = math.huge }); + room.handle_normal_presence = handle_broken_room; + room.handle_first_presence = handle_broken_room; + return room; end end -if persistent_errors then persistent_rooms_storage:set(nil, persistent_rooms); end -local host_room = muc_new_room(muc_host); -host_room.route_stanza = room_route_stanza; -host_room.save = room_save; +function forget_room(room) + module:log("debug", "Forgetting %s", room.jid); + rooms.save = nil; + rooms:set(room.jid, nil); +end -module:hook("host-disco-items", function(event) - local reply = event.reply; - module:log("debug", "host-disco-items called"); - for jid, room in pairs(rooms) do - if not room:get_hidden() then - reply:tag("item", {jid=jid, name=room:get_name()}):up(); - end +function delete_room(room) + module:log("debug", "Deleting %s", room.jid); + room_configs:set(jid_split(room.jid), nil); + persistent_rooms:set(nil, room.jid, nil); + room_items_cache[room.jid] = nil; +end + +function module.unload() + for room in rooms:values() do + room:save(nil, true); + forget_room(room); end -end); +end -local function handle_to_domain(event) - local origin, stanza = event.origin, event.stanza; - local type = stanza.attr.type; - if type == "error" or type == "result" then return; end - if stanza.name == "iq" and type == "get" then - local xmlns = stanza.tags[1].attr.xmlns; - local node = stanza.tags[1].attr.node; - if xmlns == "http://jabber.org/protocol/muc#unique" then - origin.send(st.reply(stanza):tag("unique", {xmlns = xmlns}):text(uuid_gen())); -- FIXME Random UUIDs can theoretically have collisions - else - origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); -- TODO disco/etc - end - else - host_room:handle_stanza(origin, stanza); - --origin.send(st.error_reply(stanza, "cancel", "service-unavailable", "The muc server doesn't deal with messages and presence directed at it")); +function get_room_from_jid(room_jid) + local room = rooms:get(room_jid); + if room then + room_hit(); + rooms:set(room_jid, room); -- bump to top; + return room; end - return true; + room_miss(); + return restore_room(room_jid); end -function stanza_handler(event) - local origin, stanza = event.origin, event.stanza; - local bare = jid_bare(stanza.attr.to); - local room = rooms[bare]; - if not room then - if stanza.name ~= "presence" or stanza.attr.type ~= nil then - if stanza.attr.type ~= "error" then - origin.send(st.error_reply(stanza, "cancel", "item-not-found")); +function create_room(room_jid, config) + local exists = get_room_from_jid(room_jid); + if exists then + return nil, "room-exists"; + end + local room = muclib.new_room(room_jid, config); + module:fire_event("muc-room-created", { + room = room; + }); + return track_room(room); +end + +function each_room(local_only) + if local_only then + return rooms:values(); + end + return coroutine.wrap(function () + local seen = {}; -- Don't iterate over persistent rooms twice + for room in rooms:values() do + coroutine.yield(room); + seen[room.jid] = true; + end + local all_persistent_rooms, err = persistent_rooms_storage:get(nil); + if not all_persistent_rooms then + if err then + module:log("error", "Error loading list of persistent rooms, only rooms live in memory were iterated over"); + module:log("debug", "%s", debug.traceback(err)); end - return true; + return nil; end - if not(restrict_room_creation) or - is_admin(stanza.attr.from) or - (restrict_room_creation == "local" and select(2, jid_split(stanza.attr.from)) == module.host:gsub("^[^%.]+%.", "")) then - room = create_room(bare, lock_rooms); + for room_jid in pairs(all_persistent_rooms) do + if not seen[room_jid] then + local room = restore_room(room_jid); + if room then + coroutine.yield(room); + else + module:log("error", "Missing data for room '%s', omitting from iteration", room_jid); + end + end end - end - if room then - room:handle_stanza(origin, stanza); - if not next(room._occupants) and not persistent_rooms[room.jid] then -- empty, non-persistent room - module:fire_event("muc-room-destroyed", { room = room }); - rooms[bare] = nil; -- discard room + end); +end + +module:hook("host-disco-items", function(event) + local reply = event.reply; + module:log("debug", "host-disco-items called"); + if next(room_items_cache) ~= nil then + for jid, room_name in pairs(room_items_cache) do + reply:tag("item", { jid = jid, name = room_name }):up(); end else - origin.send(st.error_reply(stanza, "cancel", "not-allowed")); + for room in each_room() do + if not room:get_hidden() then + local jid, room_name = room.jid, room:get_name(); + room_items_cache[jid] = room_name; + reply:tag("item", { jid = jid, name = room_name }):up(); + end + end end - return true; -end -module:hook("iq/bare", stanza_handler, -1); -module:hook("message/bare", stanza_handler, -1); -module:hook("presence/bare", stanza_handler, -1); -module:hook("iq/full", stanza_handler, -1); -module:hook("message/full", stanza_handler, -1); -module:hook("presence/full", stanza_handler, -1); -module:hook("iq/host", handle_to_domain, -1); -module:hook("message/host", handle_to_domain, -1); -module:hook("presence/host", handle_to_domain, -1); - -hosts[module.host].send = function(stanza) -- FIXME do a generic fix - if stanza.attr.type == "result" or stanza.attr.type == "error" then - module:send(stanza); - else error("component.send only supports result and error stanzas at the moment"); end -end +end); -hosts[module:get_host()].muc = { rooms = rooms }; +module:hook("muc-room-pre-create", function (event) + local room = event.room; + room:set_public(module:get_option_boolean("muc_room_default_public", false)); + room:set_persistent(module:get_option_boolean("muc_room_default_persistent", room:get_persistent())); + room:set_members_only(module:get_option_boolean("muc_room_default_members_only", room:get_members_only())); + room:set_moderated(module:get_option_boolean("muc_room_default_moderated", room:get_moderated())); + room:set_whois(module:get_option_boolean("muc_room_default_public_jids", room:get_whois() == "anyone") and "anyone" or "moderators"); + room:set_changesubject(module:get_option_boolean("muc_room_default_change_subject", room:get_changesubject())); + room:set_historylength(module:get_option_number("muc_room_default_history_length", room:get_historylength())); + room:set_language(event.stanza.attr["xml:lang"] or module:get_option_string("muc_room_default_language")); +end, 1); -local saved = false; -module.save = function() - saved = true; - return {rooms = rooms}; -end -module.restore = function(data) - for jid, oldroom in pairs(data.rooms or {}) do - local room = create_room(jid); - room._jid_nick = oldroom._jid_nick; - room._occupants = oldroom._occupants; - room._data = oldroom._data; - room._affiliations = oldroom._affiliations; +module:hook("muc-room-pre-create", function(event) + local origin, stanza = event.origin, event.stanza; + if not track_room(event.room) then + origin.send(st.error_reply(stanza, "wait", "resource-constraint")); + return true; + end +end, -1000); + +module:hook("muc-room-destroyed",function(event) + local room = event.room; + forget_room(room); + delete_room(room); +end); + +do + local restrict_room_creation = module:get_option("restrict_room_creation"); + if restrict_room_creation == true then + restrict_room_creation = "admin"; + end + if restrict_room_creation then + local host_suffix = module.host:gsub("^[^%.]+%.", ""); + module:hook("muc-room-pre-create", function(event) + local origin, stanza = event.origin, event.stanza; + local user_jid = stanza.attr.from; + if not is_admin(user_jid) and not ( + restrict_room_creation == "local" and + select(2, jid_split(user_jid)) == host_suffix + ) then + origin.send(st.error_reply(stanza, "cancel", "not-allowed", "Room creation is restricted")); + return true; + end + end); end - hosts[module:get_host()].muc = { rooms = rooms }; end -function shutdown_room(room, stanza) - for nick, occupant in pairs(room._occupants) do - stanza.attr.from = nick; - for jid in pairs(occupant.sessions) do - stanza.attr.to = jid; - room:_route_stanza(stanza); - room._jid_nick[jid] = nil; +for event_name, method in pairs { + -- Normal room interactions + ["iq-get/bare/http://jabber.org/protocol/disco#info:query"] = "handle_disco_info_get_query" ; + ["iq-get/bare/http://jabber.org/protocol/disco#items:query"] = "handle_disco_items_get_query" ; + ["iq-set/bare/http://jabber.org/protocol/muc#admin:query"] = "handle_admin_query_set_command" ; + ["iq-get/bare/http://jabber.org/protocol/muc#admin:query"] = "handle_admin_query_get_command" ; + ["iq-set/bare/http://jabber.org/protocol/muc#owner:query"] = "handle_owner_query_set_to_room" ; + ["iq-get/bare/http://jabber.org/protocol/muc#owner:query"] = "handle_owner_query_get_to_room" ; + ["message/bare"] = "handle_message_to_room" ; + ["presence/bare"] = "handle_presence_to_room" ; + -- Host room + ["iq-get/host/http://jabber.org/protocol/disco#info:query"] = "handle_disco_info_get_query" ; + ["iq-get/host/http://jabber.org/protocol/disco#items:query"] = "handle_disco_items_get_query" ; + ["iq-set/host/http://jabber.org/protocol/muc#admin:query"] = "handle_admin_query_set_command" ; + ["iq-get/host/http://jabber.org/protocol/muc#admin:query"] = "handle_admin_query_get_command" ; + ["iq-set/host/http://jabber.org/protocol/muc#owner:query"] = "handle_owner_query_set_to_room" ; + ["iq-get/host/http://jabber.org/protocol/muc#owner:query"] = "handle_owner_query_get_to_room" ; + ["message/host"] = "handle_message_to_room" ; + ["presence/host"] = "handle_presence_to_room" ; + -- Direct to occupant (normal rooms and host room) + ["presence/full"] = "handle_presence_to_occupant" ; + ["iq/full"] = "handle_iq_to_occupant" ; + ["message/full"] = "handle_message_to_occupant" ; +} do + module:hook(event_name, function (event) + local origin, stanza = event.origin, event.stanza; + local room_jid = jid_bare(stanza.attr.to); + local room = get_room_from_jid(room_jid); + if room == nil then + -- Watch presence to create rooms + if stanza.attr.type == nil and stanza.name == "presence" then + room = muclib.new_room(room_jid); + return room:handle_first_presence(origin, stanza); + elseif stanza.attr.type ~= "error" then + origin.send(st.error_reply(stanza, "cancel", "item-not-found")); + return true; + else + return; + end end - room._occupants[nick] = nil; - end + return room[method](room, origin, stanza); + end, -2) end + function shutdown_component() - if not saved then - local stanza = st.presence({type = "unavailable"}) - :tag("x", {xmlns = "http://jabber.org/protocol/muc#user"}) - :tag("item", { affiliation='none', role='none' }):up() - :tag("status", { code = "332"}):up(); - for roomjid, room in pairs(rooms) do - shutdown_room(room, stanza); - end - shutdown_room(host_room, stanza); + for room in each_room(true) do + room:save(nil, true); end end -module.unload = shutdown_component; -module:hook_global("server-stopping", shutdown_component); - --- Ad-hoc commands -module:depends("adhoc") -local t_concat = table.concat; -local keys = require "util.iterators".keys; -local adhoc_new = module:require "adhoc".new; -local adhoc_initial = require "util.adhoc".new_initial_data_form; -local dataforms_new = require "util.dataforms".new; - -local destroy_rooms_layout = dataforms_new { - title = "Destroy rooms"; - instructions = "Select the rooms to destroy"; - - { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/muc#destroy" }; - { name = "rooms", type = "list-multi", required = true, label = "Rooms to destroy:"}; -}; - -local destroy_rooms_handler = adhoc_initial(destroy_rooms_layout, function() - return { rooms = array.collect(keys(rooms)):sort() }; -end, function(fields, errors) - if errors then - local errmsg = {}; - for name, err in pairs(errors) do - errmsg[#errmsg + 1] = name .. ": " .. err; +module:hook_global("server-stopping", shutdown_component, -300); + +do -- Ad-hoc commands + module:depends "adhoc"; + local t_concat = table.concat; + local adhoc_new = module:require "adhoc".new; + local adhoc_initial = require "util.adhoc".new_initial_data_form; + local array = require "util.array"; + local dataforms_new = require "util.dataforms".new; + + local destroy_rooms_layout = dataforms_new { + title = "Destroy rooms"; + instructions = "Select the rooms to destroy"; + + { name = "FORM_TYPE", type = "hidden", value = "http://prosody.im/protocol/muc#destroy" }; + { name = "rooms", type = "list-multi", required = true, label = "Rooms to destroy:"}; + }; + + local destroy_rooms_handler = adhoc_initial(destroy_rooms_layout, function() + return { rooms = array.collect(each_room()):pluck("jid"):sort(); }; + end, function(fields, errors) + if errors then + local errmsg = {}; + for field, err in pairs(errors) do + errmsg[#errmsg + 1] = field .. ": " .. err; + end + return { status = "completed", error = { message = t_concat(errmsg, "\n") } }; end - return { status = "completed", error = { message = t_concat(errmsg, "\n") } }; - end - for _, room in ipairs(fields.rooms) do - rooms[room]:destroy(); - rooms[room] = nil; - end - return { status = "completed", info = "The following rooms were destroyed:\n"..t_concat(fields.rooms, "\n") }; -end); -local destroy_rooms_desc = adhoc_new("Destroy Rooms", "http://prosody.im/protocol/muc#destroy", destroy_rooms_handler, "admin"); + for _, room in ipairs(fields.rooms) do + get_room_from_jid(room):destroy(); + end + return { status = "completed", info = "The following rooms were destroyed:\n"..t_concat(fields.rooms, "\n") }; + end); + local destroy_rooms_desc = adhoc_new("Destroy Rooms", "http://prosody.im/protocol/muc#destroy", destroy_rooms_handler, "admin"); -module:provides("adhoc", destroy_rooms_desc); + module:provides("adhoc", destroy_rooms_desc); +end diff --git a/plugins/muc/moderated.lib.lua b/plugins/muc/moderated.lib.lua new file mode 100644 index 00000000..8354c585 --- /dev/null +++ b/plugins/muc/moderated.lib.lua @@ -0,0 +1,51 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local function get_moderated(room) + return room._data.moderated; +end + +local function set_moderated(room, moderated) + moderated = moderated and true or nil; + if get_moderated(room) == moderated then return false; end + room._data.moderated = moderated; + return true; +end + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = get_moderated(event.room) and "muc_moderated" or "muc_unmoderated"}):up(); +end); + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_moderatedroom"; + type = "boolean"; + label = "Make Room Moderated?"; + value = get_moderated(event.room); + }); +end, 100-4); + +module:hook("muc-config-submitted/muc#roomconfig_moderatedroom", function(event) + if set_moderated(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +module:hook("muc-get-default-role", function(event) + if event.affiliation == nil then + if get_moderated(event.room) then + return "visitor" + end + end +end, 1); + +return { + get = get_moderated; + set = set_moderated; +}; diff --git a/plugins/muc/muc.lib.lua b/plugins/muc/muc.lib.lua index 8257b0b7..b48ee7db 100644 --- a/plugins/muc/muc.lib.lua +++ b/plugins/muc/muc.lib.lua @@ -1,66 +1,34 @@ -- Prosody IM -- Copyright (C) 2008-2010 Matthew Wild -- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator -- -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- local select = select; -local pairs, ipairs = pairs, ipairs; - -local datetime = require "util.datetime"; +local pairs = pairs; +local next = next; +local setmetatable = setmetatable; local dataform = require "util.dataforms"; - +local iterators = require "util.iterators"; local jid_split = require "util.jid".split; local jid_bare = require "util.jid".bare; local jid_prep = require "util.jid".prep; +local jid_join = require "util.jid".join; +local jid_resource = require "util.jid".resource; local st = require "util.stanza"; -local log = require "util.logger".init("mod_muc"); -local t_insert, t_remove = table.insert, table.remove; -local setmetatable = setmetatable; local base64 = require "util.encodings".base64; local md5 = require "util.hashes".md5; -local muc_domain = nil; --module:get_host(); -local default_history_length, max_history_length = 20, math.huge; - ------------- -local presence_filters = {["http://jabber.org/protocol/muc"]=true;["http://jabber.org/protocol/muc#user"]=true}; -local function presence_filter(tag) - if presence_filters[tag.attr.xmlns] then - return nil; - end - return tag; -end - -local function get_filtered_presence(stanza) - return st.clone(stanza):maptags(presence_filter); -end -local kickable_error_conditions = { - ["gone"] = true; - ["internal-server-error"] = true; - ["item-not-found"] = true; - ["jid-malformed"] = true; - ["recipient-unavailable"] = true; - ["redirect"] = true; - ["remote-server-not-found"] = true; - ["remote-server-timeout"] = true; - ["service-unavailable"] = true; - ["malformed error"] = true; -}; - -local function get_error_condition(stanza) - local _, condition = stanza:get_error(); - return condition or "malformed error"; -end +local log = module._log; -local function is_kickable_error(stanza) - local cond = get_error_condition(stanza); - return kickable_error_conditions[cond] and cond; -end ------------ +local occupant_lib = module:require "muc/occupant" +local muc_util = module:require "muc/util"; +local is_kickable_error = muc_util.is_kickable_error; +local valid_roles, valid_affiliations = muc_util.valid_roles, muc_util.valid_affiliations; local room_mt = {}; room_mt.__index = room_mt; @@ -69,37 +37,147 @@ function room_mt:__tostring() return "MUC room ("..self.jid..")"; end +function room_mt.save() + -- overriden by mod_muc.lua +end + +function room_mt:get_occupant_jid(real_jid) + return self._jid_nick[real_jid] +end + function room_mt:get_default_role(affiliation) - if affiliation == "owner" or affiliation == "admin" then + local role = module:fire_event("muc-get-default-role", { + room = self; + affiliation = affiliation; + affiliation_rank = valid_affiliations[affiliation or "none"]; + }); + role = role ~= "none" and role or nil; -- coerces `role == false` to `nil` + return role, valid_roles[role or "none"]; +end +module:hook("muc-get-default-role", function(event) + if event.affiliation_rank >= valid_affiliations.admin then return "moderator"; - elseif affiliation == "member" then + elseif event.affiliation_rank >= valid_affiliations.none then return "participant"; - elseif not affiliation then - if not self:get_members_only() then - return self:get_moderated() and "visitor" or "participant"; - end end +end); + +--- Occupant functions +function room_mt:new_occupant(bare_real_jid, nick) + local occupant = occupant_lib.new(bare_real_jid, nick); + local affiliation = self:get_affiliation(bare_real_jid); + occupant.role = self:get_default_role(affiliation); + return occupant; +end + +function room_mt:get_occupant_by_nick(nick) + local occupant = self._occupants[nick]; + if occupant == nil then return nil end + return occupant_lib.copy(occupant); end -function room_mt:broadcast_presence(stanza, sid, code, nick) - stanza = get_filtered_presence(stanza); - local occupant = self._occupants[stanza.attr.from]; - stanza:tag("x", {xmlns='http://jabber.org/protocol/muc#user'}) - :tag("item", {affiliation=occupant.affiliation or "none", role=occupant.role or "none", nick=nick}):up(); - if code then - stanza:tag("status", {code=code}):up(); - end - self:broadcast_except_nick(stanza, stanza.attr.from); - local me = self._occupants[stanza.attr.from]; - if me then - stanza:tag("status", {code='110'}):up(); - stanza.attr.to = sid; - self:_route_stanza(stanza); +do + local function next_copied_occupant(occupants, occupant_jid) + local next_occupant_jid, raw_occupant = next(occupants, occupant_jid); + if next_occupant_jid == nil then return nil end + return next_occupant_jid, occupant_lib.copy(raw_occupant); end + -- FIXME Explain what 'read_only' is supposed to be + function room_mt:each_occupant(read_only) -- luacheck: ignore 212 + return next_copied_occupant, self._occupants, nil; + end +end + +function room_mt:has_occupant() + return next(self._occupants, nil) ~= nil end -function room_mt:broadcast_message(stanza, historic) + +function room_mt:get_occupant_by_real_jid(real_jid) + local occupant_jid = self:get_occupant_jid(real_jid); + if occupant_jid == nil then return nil end + return self:get_occupant_by_nick(occupant_jid); +end + +function room_mt:save_occupant(occupant) + occupant = occupant_lib.copy(occupant); -- So that occupant can be modified more + local id = occupant.nick + + -- Need to maintain _jid_nick secondary index + local old_occupant = self._occupants[id]; + if old_occupant then + for real_jid in old_occupant:each_session() do + self._jid_nick[real_jid] = nil; + end + end + + local has_live_session = false + if occupant.role ~= nil then + for real_jid, presence in occupant:each_session() do + if presence.attr.type == nil then + has_live_session = true + self._jid_nick[real_jid] = occupant.nick; + end + end + if not has_live_session then + -- Has no live sessions left; they have left the room. + occupant.role = nil + end + end + if not has_live_session then + occupant = nil + end + self._occupants[id] = occupant +end + +function room_mt:route_to_occupant(occupant, stanza) local to = stanza.attr.to; - local room_jid = self.jid; + for jid in occupant:each_session() do + stanza.attr.to = jid; + self:route_stanza(stanza); + end + stanza.attr.to = to; +end + +-- actor is the attribute table +local function add_item(x, affiliation, role, jid, nick, actor_nick, actor_jid, reason) + x:tag("item", {affiliation = affiliation; role = role; jid = jid; nick = nick;}) + if actor_nick or actor_jid then + x:tag("actor", {nick = actor_nick; jid = actor_jid;}):up() + end + if reason then + x:tag("reason"):text(reason):up() + end + x:up(); + return x +end + +-- actor is (real) jid +function room_mt:build_item_list(occupant, x, is_anonymous, nick, actor_nick, actor_jid, reason) + local affiliation = self:get_affiliation(occupant.bare_jid) or "none"; + local role = occupant.role or "none"; + if is_anonymous then + add_item(x, affiliation, role, nil, nick, actor_nick, actor_jid, reason); + else + for real_jid in occupant:each_session() do + add_item(x, affiliation, role, real_jid, nick, actor_nick, actor_jid, reason); + end + end + return x +end + +function room_mt:broadcast_message(stanza) + if module:fire_event("muc-broadcast-message", {room = self, stanza = stanza}) then + return true; + end + self:broadcast(stanza); + return true; +end + +-- Strip delay tags claiming to be from us +module:hook("muc-occupant-groupchat", function (event) + local stanza = event.stanza; + local room = event.room; + local room_jid = room.jid; stanza:maptags(function (child) if child.name == "delay" and child.attr["xmlns"] == "urn:xmpp:delay" then @@ -114,509 +192,563 @@ function room_mt:broadcast_message(stanza, historic) end return child; end) +end); - for occupant, o_data in pairs(self._occupants) do - for jid in pairs(o_data.sessions) do - stanza.attr.to = jid; - self:_route_stanza(stanza); +-- Broadcast a stanza to all occupants in the room. +-- optionally checks conditional called with (nick, occupant) +function room_mt:broadcast(stanza, cond_func) + for nick, occupant in self:each_occupant() do + if cond_func == nil or cond_func(nick, occupant) then + self:route_to_occupant(occupant, stanza) end end - stanza.attr.to = to; - if historic then -- add to history - return self:save_to_history(stanza) - end end -function room_mt:save_to_history(stanza) - local history = self._data['history']; - if not history then history = {}; self._data['history'] = history; end - stanza = st.clone(stanza); - stanza.attr.to = ""; - local stamp = datetime.datetime(); - stanza:tag("delay", {xmlns = "urn:xmpp:delay", from = self.jid, stamp = stamp}):up(); -- XEP-0203 - stanza:tag("x", {xmlns = "jabber:x:delay", from = self.jid, stamp = datetime.legacy()}):up(); -- XEP-0091 (deprecated) - local entry = { stanza = stanza, stamp = stamp }; - t_insert(history, entry); - while #history > (self._data.history_length or default_history_length) do t_remove(history, 1) end + +local function can_see_real_jids(whois, occupant) + if whois == "anyone" then + return true; + elseif whois == "moderators" then + return valid_roles[occupant.role or "none"] >= valid_roles.moderator; + end end -function room_mt:broadcast_except_nick(stanza, nick) - for rnick, occupant in pairs(self._occupants) do - if rnick ~= nick then - for jid in pairs(occupant.sessions) do - stanza.attr.to = jid; - self:_route_stanza(stanza); + +-- Broadcasts an occupant's presence to the whole room +-- Takes the x element that goes into the stanzas +function room_mt:publicise_occupant_status(occupant, x, nick, actor, reason) + local base_x = x.base or x; + -- Build real jid and (optionally) occupant jid template presences + local base_presence do + -- Try to use main jid's presence + local pr = occupant:get_presence(); + if pr and (pr.attr.type ~= "unavailable" and occupant.role ~= nil) then + base_presence = st.clone(pr); + else -- user is leaving but didn't send a leave presence. make one for them + base_presence = st.presence {from = occupant.nick; type = "unavailable";}; + end + end + + -- Fire event (before full_p and anon_p are created) + local event = { + room = self; stanza = base_presence; x = base_x; + occupant = occupant; nick = nick; actor = actor; + reason = reason; + } + module:fire_event("muc-broadcast-presence", event); + + -- Allow muc-broadcast-presence listeners to change things + nick = event.nick; + actor = event.actor; + reason = event.reason; + + local whois = self:get_whois(); + + local actor_nick; + if actor then + actor_nick = jid_resource(self:get_occupant_jid(actor)); + end + + local full_p, full_x; + local function get_full_p() + if full_p == nil then + full_x = st.clone(x.full or base_x); + self:build_item_list(occupant, full_x, false, nick, actor_nick, actor, reason); + full_p = st.clone(base_presence):add_child(full_x); + end + return full_p, full_x; + end + + local anon_p, anon_x; + local function get_anon_p() + if anon_p == nil then + anon_x = st.clone(x.anon or base_x); + self:build_item_list(occupant, anon_x, true, nick, actor_nick, nil, reason); + anon_p = st.clone(base_presence):add_child(anon_x); + end + return anon_p, anon_x; + end + + local self_p, self_x; + if can_see_real_jids(whois, occupant) then + self_p, self_x = get_full_p(); + else + -- Can always see your own full jids + -- But not allowed to see actor's + self_x = st.clone(x.self or base_x); + self:build_item_list(occupant, self_x, false, nick, actor_nick, nil, reason); + self_p = st.clone(base_presence):add_child(self_x); + end + + -- General populance + for occupant_nick, n_occupant in self:each_occupant() do + if occupant_nick ~= occupant.nick then + local pr; + if can_see_real_jids(whois, n_occupant) then + pr = get_full_p(); + elseif occupant.bare_jid == n_occupant.bare_jid then + pr = self_p; + else + pr = get_anon_p(); end + self:route_to_occupant(n_occupant, pr); end end -end -function room_mt:send_occupant_list(to) - local current_nick = self._jid_nick[to]; - for occupant, o_data in pairs(self._occupants) do - if occupant ~= current_nick then - local pres = get_filtered_presence(o_data.sessions[o_data.jid]); - pres.attr.to, pres.attr.from = to, occupant; - pres:tag("x", {xmlns='http://jabber.org/protocol/muc#user'}) - :tag("item", {affiliation=o_data.affiliation or "none", role=o_data.role or "none"}):up(); - self:_route_stanza(pres); + -- Presences for occupant itself + self_x:tag("status", {code = "110";}):up(); + if occupant.role == nil then + -- They get an unavailable + self:route_to_occupant(occupant, self_p); + else + -- use their own presences as templates + for full_jid, pr in occupant:each_session() do + pr = st.clone(pr); + pr.attr.to = full_jid; + pr:add_child(self_x); + self:route_stanza(pr); end end end -function room_mt:send_history(to, stanza) - local history = self._data['history']; -- send discussion history - if history then - local x_tag = stanza and stanza:get_child("x", "http://jabber.org/protocol/muc"); - local history_tag = x_tag and x_tag:get_child("history", "http://jabber.org/protocol/muc"); - - local maxchars = history_tag and tonumber(history_tag.attr.maxchars); - if maxchars then maxchars = math.floor(maxchars); end - - local maxstanzas = math.floor(history_tag and tonumber(history_tag.attr.maxstanzas) or #history); - if not history_tag then maxstanzas = 20; end - - local seconds = history_tag and tonumber(history_tag.attr.seconds); - if seconds then seconds = datetime.datetime(os.time() - math.floor(seconds)); end - - local since = history_tag and history_tag.attr.since; - if since then since = datetime.parse(since); since = since and datetime.datetime(since); end - if seconds and (not since or since < seconds) then since = seconds; end - - local n = 0; - local charcount = 0; - - for i=#history,1,-1 do - local entry = history[i]; - if maxchars then - if not entry.chars then - entry.stanza.attr.to = ""; - entry.chars = #tostring(entry.stanza); - end - charcount = charcount + entry.chars + #to; - if charcount > maxchars then break; end + +function room_mt:send_occupant_list(to, filter) + local to_bare = jid_bare(to); + local is_anonymous = false; + local whois = self:get_whois(); + if whois ~= "anyone" then + local affiliation = self:get_affiliation(to); + if affiliation ~= "admin" and affiliation ~= "owner" then + local occupant = self:get_occupant_by_real_jid(to); + if not (occupant and can_see_real_jids(whois, occupant)) then + is_anonymous = true; end - if since and since > entry.stamp then break; end - if n + 1 > maxstanzas then break; end - n = n + 1; end - for i=#history-n+1,#history do - local msg = history[i].stanza; - msg.attr.to = to; - self:_route_stanza(msg); + end + for occupant_jid, occupant in self:each_occupant() do + if filter == nil or filter(occupant_jid, occupant) then + local x = st.stanza("x", {xmlns='http://jabber.org/protocol/muc#user'}); + self:build_item_list(occupant, x, is_anonymous and to_bare ~= occupant.bare_jid); -- can always see your own jids + local pres = st.clone(occupant:get_presence()); + pres.attr.to = to; + pres:add_child(x); + self:route_stanza(pres); end end end -function room_mt:send_subject(to) - self:_route_stanza(st.message({type='groupchat', from=self._data['subject_from'] or self.jid, to=to}):tag("subject"):text(self._data['subject'])); -end function room_mt:get_disco_info(stanza) - local count = 0; for _ in pairs(self._occupants) do count = count + 1; end - local reply = st.reply(stanza):query("http://jabber.org/protocol/disco#info") - :tag("identity", {category="conference", type="text", name=self:get_name()}):up() - :tag("feature", {var="http://jabber.org/protocol/muc"}):up() - :tag("feature", {var="http://jabber.org/protocol/muc#stable_id"}):up() - :tag("feature", {var=self:get_password() and "muc_passwordprotected" or "muc_unsecured"}):up() - :tag("feature", {var=self:get_moderated() and "muc_moderated" or "muc_unmoderated"}):up() - :tag("feature", {var=self:get_members_only() and "muc_membersonly" or "muc_open"}):up() - :tag("feature", {var=self:get_persistent() and "muc_persistent" or "muc_temporary"}):up() - :tag("feature", {var=self:get_hidden() and "muc_hidden" or "muc_public"}):up() - :tag("feature", {var=self._data.whois ~= "anyone" and "muc_semianonymous" or "muc_nonanonymous"}):up() - ; - local dataform = dataform.new({ - { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/muc#roominfo" }, - { name = "muc#roominfo_description", label = "Description", value = "" }, - { name = "muc#roominfo_occupants", label = "Number of occupants", value = "" } - }); - local formdata = { - ["muc#roominfo_description"] = self:get_description(), - ["muc#roominfo_occupants"] = tostring(count), + local reply = st.reply(stanza):query("http://jabber.org/protocol/disco#info"); + local form = dataform.new { + { name = "FORM_TYPE", type = "hidden", value = "http://jabber.org/protocol/muc#roominfo" }; }; - module:fire_event("muc-disco#info", { room = self, reply = reply, form = dataform, formdata = formdata }); - reply:add_child(dataform:form(formdata, 'result')) + local formdata = {}; + module:fire_event("muc-disco#info", {room = self; reply = reply; form = form, formdata = formdata ;}); + reply:add_child(form:form(formdata, "result")); return reply; end -function room_mt:get_disco_items(stanza) +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = "http://jabber.org/protocol/muc"}):up(); + event.reply:tag("feature", {var = "http://jabber.org/protocol/muc#stable_id"}):up(); +end); +module:hook("muc-disco#info", function(event) + table.insert(event.form, { name = "muc#roominfo_occupants", label = "Number of occupants" }); + event.formdata["muc#roominfo_occupants"] = tostring(iterators.count(event.room:each_occupant())); +end); + +function room_mt:get_disco_items(stanza) -- luacheck: ignore 212 return st.reply(stanza):query("http://jabber.org/protocol/disco#items"); end -function room_mt:set_subject(current_nick, subject) - if subject == "" then subject = nil; end - self._data['subject'] = subject; - self._data['subject_from'] = current_nick; - if self.save then self:save(); end - local msg = st.message({type='groupchat', from=current_nick}) - :tag('subject'):text(subject):up(); - self:broadcast_message(msg, false); - return true; -end -local function build_unavailable_presence_from_error(stanza) +function room_mt:handle_kickable(origin, stanza) -- luacheck: ignore 212 + local real_jid = stanza.attr.from; + local occupant = self:get_occupant_by_real_jid(real_jid); + if occupant == nil then return nil; end local type, condition, text = stanza:get_error(); local error_message = "Kicked: "..(condition and condition:gsub("%-", " ") or "presence error"); if text then error_message = error_message..": "..text; end - return st.presence({type='unavailable', from=stanza.attr.from, to=stanza.attr.to}) - :tag('status'):text(error_message); + occupant:set_session(real_jid, st.presence({type="unavailable"}) + :tag('status'):text(error_message)); + self:save_occupant(occupant); + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}) + :tag("status", {code = "307"}):up() + :tag("status", {code = "333"}) + self:publicise_occupant_status(occupant, x); + if occupant.jid == real_jid then -- Was last session + module:fire_event("muc-occupant-left", {room = self; nick = occupant.nick; occupant = occupant;}); + end + return true; end -function room_mt:set_name(name) - if name == "" or type(name) ~= "string" or name == (jid_split(self.jid)) then name = nil; end - if self._data.name ~= name then - self._data.name = name; - if self.save then self:save(true); end +-- Give the room creator owner affiliation +module:hook("muc-room-pre-create", function(event) + event.room:set_affiliation(true, jid_bare(event.stanza.attr.from), "owner"); +end, -1); + +-- check if user is banned +module:hook("muc-occupant-pre-join", function(event) + local room, stanza = event.room, event.stanza; + local affiliation = room:get_affiliation(stanza.attr.from); + if affiliation == "outcast" then + local reply = st.error_reply(stanza, "auth", "forbidden"):up(); + reply.tags[1].attr.code = "403"; + event.origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); + return true; end -end -function room_mt:get_name() - return self._data.name or jid_split(self.jid); -end -function room_mt:set_description(description) - if description == "" or type(description) ~= "string" then description = nil; end - if self._data.description ~= description then - self._data.description = description; - if self.save then self:save(true); end +end, -10); + +function room_mt:handle_first_presence(origin, stanza) + if not stanza:get_child("x", "http://jabber.org/protocol/muc") then + module:log("debug", "Room creation without <x>, possibly desynced"); + + origin.send(st.error_reply(stanza, "cancel", "item-not-found")); + return true; end -end -function room_mt:get_description() - return self._data.description; -end -function room_mt:set_password(password) - if password == "" or type(password) ~= "string" then password = nil; end - if self._data.password ~= password then - self._data.password = password; - if self.save then self:save(true); end + + local real_jid = stanza.attr.from; + local dest_jid = stanza.attr.to; + local bare_jid = jid_bare(real_jid); + if module:fire_event("muc-room-pre-create", { + room = self; + origin = origin; + stanza = stanza; + }) then return true; end + local is_first_dest_session = true; + local dest_occupant = self:new_occupant(bare_jid, dest_jid); + + local orig_nick = dest_occupant.nick; + if module:fire_event("muc-occupant-pre-join", { + room = self; + origin = origin; + stanza = stanza; + is_first_session = is_first_dest_session; + is_new_room = true; + occupant = dest_occupant; + }) then return true; end + local nick_changed = orig_nick ~= dest_occupant.nick; + + dest_occupant:set_session(real_jid, stanza); + local dest_x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}); + dest_x:tag("status", {code = "201"}):up(); + if self:get_whois() == "anyone" then + dest_x:tag("status", {code = "100"}):up(); end -end -function room_mt:get_password() - return self._data.password; -end -function room_mt:set_moderated(moderated) - moderated = moderated and true or nil; - if self._data.moderated ~= moderated then - self._data.moderated = moderated; - if self.save then self:save(true); end + local self_x; + if nick_changed then + self_x = st.clone(dest_x); + self_x:tag("status", {code = "210"}):up(); end + self:save_occupant(dest_occupant); + + self:publicise_occupant_status(dest_occupant, {base = dest_x, self = self_x}); + + module:fire_event("muc-occupant-joined", { + room = self; + nick = dest_occupant.nick; + occupant = dest_occupant; + stanza = stanza; + origin = origin; + }); + module:fire_event("muc-occupant-session-new", { + room = self; + nick = dest_occupant.nick; + occupant = dest_occupant; + stanza = stanza; + origin = origin; + jid = real_jid; + }); + module:fire_event("muc-room-created", { + room = self; + creator = dest_occupant; + stanza = stanza; + origin = origin; + }); + return true; end -function room_mt:get_moderated() - return self._data.moderated; -end -function room_mt:set_members_only(members_only) - members_only = members_only and true or nil; - if self._data.members_only ~= members_only then - self._data.members_only = members_only; - if self.save then self:save(true); end + +function room_mt:handle_normal_presence(origin, stanza) + local type = stanza.attr.type; + local real_jid = stanza.attr.from; + local bare_jid = jid_bare(real_jid); + local orig_occupant = self:get_occupant_by_real_jid(real_jid); + local muc_x = stanza:get_child("x", "http://jabber.org/protocol/muc"); + + if orig_occupant == nil and not muc_x and stanza.attr.type == nil then + module:log("debug", "Attempted join without <x>, possibly desynced"); + origin.send(st.error_reply(stanza, "cancel", "item-not-found", "You must join the room before sending presence updates")); + return true; end -end -function room_mt:get_members_only() - return self._data.members_only; -end -function room_mt:set_persistent(persistent) - persistent = persistent and true or nil; - if self._data.persistent ~= persistent then - self._data.persistent = persistent; - if self.save then self:save(true); end + + local is_first_dest_session; + local dest_occupant; + if type == "unavailable" then + if orig_occupant == nil then return true; end -- Unavailable from someone not in the room + -- dest_occupant = nil + elseif orig_occupant and orig_occupant.nick == stanza.attr.to then -- Just a presence update + log("debug", "presence update for %s from session %s", orig_occupant.nick, real_jid); + dest_occupant = orig_occupant; + else + local dest_jid = stanza.attr.to; + dest_occupant = self:get_occupant_by_nick(dest_jid); + if dest_occupant == nil then + log("debug", "no occupant found for %s; creating new occupant object for %s", dest_jid, real_jid); + is_first_dest_session = true; + dest_occupant = self:new_occupant(bare_jid, dest_jid); + else + is_first_dest_session = false; + end end -end -function room_mt:get_persistent() - return self._data.persistent; -end -function room_mt:set_hidden(hidden) - hidden = hidden and true or nil; - if self._data.hidden ~= hidden then - self._data.hidden = hidden; - if self.save then self:save(true); end + local is_last_orig_session; + if orig_occupant ~= nil then + -- Is there are least 2 sessions? + local iter, ob, last = orig_occupant:each_session(); + is_last_orig_session = iter(ob, iter(ob, last)) == nil; end -end -function room_mt:get_hidden() - return self._data.hidden; -end -function room_mt:get_public() - return not self:get_hidden(); -end -function room_mt:set_public(public) - return self:set_hidden(not public); -end -function room_mt:set_changesubject(changesubject) - changesubject = changesubject and true or nil; - if self._data.changesubject ~= changesubject then - self._data.changesubject = changesubject; - if self.save then self:save(true); end + + local orig_nick = dest_occupant and dest_occupant.nick; + + local event, event_name = { + room = self; + origin = origin; + stanza = stanza; + is_first_session = is_first_dest_session; + is_last_session = is_last_orig_session; + }; + if orig_occupant == nil then + event_name = "muc-occupant-pre-join"; + event.occupant = dest_occupant; + elseif dest_occupant == nil then + event_name = "muc-occupant-pre-leave"; + event.occupant = orig_occupant; + else + event_name = "muc-occupant-pre-change"; + event.orig_occupant = orig_occupant; + event.dest_occupant = dest_occupant; end -end -function room_mt:get_changesubject() - return self._data.changesubject; -end -function room_mt:get_historylength() - return self._data.history_length or default_history_length; -end -function room_mt:set_historylength(length) - length = math.min(tonumber(length) or default_history_length, max_history_length or math.huge); - if length == default_history_length then - length = nil; + if module:fire_event(event_name, event) then return true; end + + local nick_changed = dest_occupant and orig_nick ~= dest_occupant.nick; + + -- Check for nick conflicts + if dest_occupant ~= nil and not is_first_dest_session + and bare_jid ~= jid_bare(dest_occupant.bare_jid) then + -- new nick or has different bare real jid + log("debug", "%s couldn't join due to nick conflict: %s", real_jid, dest_occupant.nick); + local reply = st.error_reply(stanza, "cancel", "conflict"):up(); + reply.tags[1].attr.code = "409"; + origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); + return true; end - self._data.history_length = length; -end + -- Send presence stanza about original occupant + if orig_occupant ~= nil and orig_occupant ~= dest_occupant then + local orig_x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}); + local dest_nick; + if dest_occupant == nil then -- Session is leaving + log("debug", "session %s is leaving occupant %s", real_jid, orig_occupant.nick); + if is_last_orig_session then + orig_occupant.role = nil; + end + orig_occupant:set_session(real_jid, stanza); + else + log("debug", "session %s is changing from occupant %s to %s", real_jid, orig_occupant.nick, dest_occupant.nick); + local generated_unavail = st.presence {from = orig_occupant.nick, to = real_jid, type = "unavailable"}; + orig_occupant:set_session(real_jid, generated_unavail); + dest_nick = jid_resource(dest_occupant.nick); + if not is_first_dest_session then -- User is swapping into another pre-existing session + log("debug", "session %s is swapping into multisession %s, showing it leave.", real_jid, dest_occupant.nick); + -- Show the other session leaving + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}); + add_item(x, self:get_affiliation(bare_jid), "none"); + local pr = st.presence{from = dest_occupant.nick, to = real_jid, type = "unavailable"} + :tag("status"):text("you are joining pre-existing session " .. dest_nick):up() + :add_child(x); + self:route_stanza(pr); + end + if is_first_dest_session and is_last_orig_session then -- Normal nick change + log("debug", "no sessions in %s left; publicly marking as nick change", orig_occupant.nick); + orig_x:tag("status", {code = "303";}):up(); + else -- The session itself always needs to see a nick change + -- don't want to get our old nick's available presence, + -- so remove our session from there, and manually generate an unavailable + orig_occupant:remove_session(real_jid); + log("debug", "generating nick change for %s", real_jid); + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}); + -- self:build_item_list(orig_occupant, x, false, dest_nick); -- COMPAT: clients get confused if they see other items besides their own + add_item(x, self:get_affiliation(bare_jid), orig_occupant.role, real_jid, dest_nick); + x:tag("status", {code = "303";}):up(); + x:tag("status", {code = "110";}):up(); + self:route_stanza(generated_unavail:add_child(x)); + dest_nick = nil; -- set dest_nick to nil; so general populance doesn't see it for whole orig_occupant + end + end -local valid_whois = { moderators = true, anyone = true }; + self:save_occupant(orig_occupant); + self:publicise_occupant_status(orig_occupant, orig_x, dest_nick); -function room_mt:set_whois(whois) - if valid_whois[whois] and self._data.whois ~= whois then - self._data.whois = whois; - if self.save then self:save(true); end + if is_last_orig_session then + module:fire_event("muc-occupant-left", { + room = self; + nick = orig_occupant.nick; + occupant = orig_occupant; + origin = origin; + stanza = stanza; + }); + end end -end -function room_mt:get_whois() - return self._data.whois; -end + if dest_occupant ~= nil then + dest_occupant:set_session(real_jid, stanza); + local dest_x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}); + if orig_occupant == nil and self:get_whois() == "anyone" then + dest_x:tag("status", {code = "100"}):up(); + end + self:save_occupant(dest_occupant); -local function construct_stanza_id(room, stanza) - local from_jid, to_nick = stanza.attr.from, stanza.attr.to; - local from_nick = room._jid_nick[from_jid]; - local occupant = room._occupants[to_nick]; - local to_jid = occupant.jid; + if orig_occupant == nil or muc_x then + -- Send occupant list to newly joined or desynced user + self:send_occupant_list(real_jid, function(nick, occupant) -- luacheck: ignore 212 + -- Don't include self + return occupant:get_presence(real_jid) == nil; + end) + end + local self_x; + if nick_changed then + self_x = st.clone(dest_x); + self_x:tag("status", {code="210"}):up(); + end + self:publicise_occupant_status(dest_occupant, {base=dest_x,self=self_x}); - return from_nick, to_jid, base64.encode(to_jid.."\0"..stanza.attr.id.."\0"..md5(from_jid)); -end -local function deconstruct_stanza_id(room, stanza) - local from_jid_possiblybare, to_nick = stanza.attr.from, stanza.attr.to; - local from_jid, id, to_jid_hash = (base64.decode(stanza.attr.id) or ""):match("^(%Z+)%z(%Z*)%z(.+)$"); - local from_nick = room._jid_nick[from_jid]; - - if not(from_nick) then return; end - if not(from_jid_possiblybare == from_jid or from_jid_possiblybare == jid_bare(from_jid)) then return; end - - local occupant = room._occupants[to_nick]; - for to_jid in pairs(occupant and occupant.sessions or {}) do - if md5(to_jid) == to_jid_hash then - return from_nick, to_jid, id; + if orig_occupant ~= nil and orig_occupant ~= dest_occupant and not is_last_orig_session then -- If user is swapping and wasn't last original session + log("debug", "session %s split nicks; showing %s rejoining", real_jid, orig_occupant.nick); + -- Show the original nick joining again + local pr = st.clone(orig_occupant:get_presence()); + pr.attr.to = real_jid; + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user";}); + self:build_item_list(orig_occupant, x, false); + -- TODO: new status code to inform client this was the multi-session it left? + pr:add_child(x); + self:route_stanza(pr); + end + + if orig_occupant == nil or muc_x then + if is_first_dest_session then + module:fire_event("muc-occupant-joined", { + room = self; + nick = dest_occupant.nick; + occupant = dest_occupant; + stanza = stanza; + origin = origin; + }); + end + module:fire_event("muc-occupant-session-new", { + room = self; + nick = dest_occupant.nick; + occupant = dest_occupant; + stanza = stanza; + origin = origin; + jid = real_jid; + }); end end + return true; end +function room_mt:handle_presence_to_occupant(origin, stanza) + local type = stanza.attr.type; + if type == "error" then -- error, kick em out! + return self:handle_kickable(origin, stanza) + elseif type == nil or type == "unavailable" then + return self:handle_normal_presence(origin, stanza); + elseif type ~= 'result' then -- bad type + if type ~= 'visible' and type ~= 'invisible' then -- COMPAT ejabberd can broadcast or forward XEP-0018 presences + origin.send(st.error_reply(stanza, "modify", "bad-request")); -- FIXME correct error? + end + end + return true; +end -function room_mt:handle_to_occupant(origin, stanza) -- PM, vCards, etc +function room_mt:handle_iq_to_occupant(origin, stanza) local from, to = stanza.attr.from, stanza.attr.to; - local room = jid_bare(to); - local current_nick = self._jid_nick[from]; local type = stanza.attr.type; - log("debug", "room: %s, current_nick: %s, stanza: %s", room or "nil", current_nick or "nil", stanza:top_tag()); - if (select(2, jid_split(from)) == muc_domain) then error("Presence from the MUC itself!!!"); end - if stanza.name == "presence" then - local pr = get_filtered_presence(stanza); - pr.attr.from = current_nick; - if type == "error" then -- error, kick em out! - if current_nick then - log("debug", "kicking %s from %s", current_nick, room); - self:handle_to_occupant(origin, build_unavailable_presence_from_error(stanza)); - end - elseif type == "unavailable" then -- unavailable - if current_nick then - log("debug", "%s leaving %s", current_nick, room); - self._jid_nick[from] = nil; - local occupant = self._occupants[current_nick]; - local new_jid = next(occupant.sessions); - if new_jid == from then new_jid = next(occupant.sessions, new_jid); end - if new_jid then - local jid = occupant.jid; - occupant.jid = new_jid; - occupant.sessions[from] = nil; - pr.attr.to = from; - pr:tag("x", {xmlns='http://jabber.org/protocol/muc#user'}) - :tag("item", {affiliation=occupant.affiliation or "none", role='none'}):up() - :tag("status", {code='110'}):up(); - self:_route_stanza(pr); - if jid ~= new_jid then - pr = st.clone(occupant.sessions[new_jid]) - :tag("x", {xmlns='http://jabber.org/protocol/muc#user'}) - :tag("item", {affiliation=occupant.affiliation or "none", role=occupant.role or "none"}); - pr.attr.from = current_nick; - self:broadcast_except_nick(pr, current_nick); - end - else - occupant.role = 'none'; - self:broadcast_presence(pr, from); - self._occupants[current_nick] = nil; - end - end - elseif not type then -- available - if current_nick then - --if #pr == #stanza or current_nick ~= to then -- commented because google keeps resending directed presence - if current_nick == to then -- simple presence - log("debug", "%s broadcasted presence", current_nick); - self._occupants[current_nick].sessions[from] = pr; - self:broadcast_presence(pr, from); - else -- change nick - -- a MUC service MUST NOT allow empty or invisible Room Nicknames - -- (i.e., Room Nicknames that consist only of one or more space characters). - if not select(3, jid_split(to)):find("[^ ]") then -- resourceprep turns all whitespace into 0x20 - module:log("debug", "Rejecting invisible nickname"); - origin.send(st.error_reply(stanza, "cancel", "not-allowed")); - return; - end - local occupant = self._occupants[current_nick]; - local is_multisession = next(occupant.sessions, next(occupant.sessions)); - if self._occupants[to] or is_multisession then - log("debug", "%s couldn't change nick", current_nick); - local reply = st.error_reply(stanza, "cancel", "conflict"):up(); - reply.tags[1].attr.code = "409"; - origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); - else - local data = self._occupants[current_nick]; - local to_nick = select(3, jid_split(to)); - if to_nick then - log("debug", "%s (%s) changing nick to %s", current_nick, data.jid, to); - local p = st.presence({type='unavailable', from=current_nick}); - self:broadcast_presence(p, from, '303', to_nick); - self._occupants[current_nick] = nil; - self._occupants[to] = data; - self._jid_nick[from] = to; - pr.attr.from = to; - self._occupants[to].sessions[from] = pr; - self:broadcast_presence(pr, from); - else - --TODO malformed-jid - end - end - end - --else -- possible rejoin - -- log("debug", "%s had connection replaced", current_nick); - -- self:handle_to_occupant(origin, st.presence({type='unavailable', from=from, to=to}) - -- :tag('status'):text('Replaced by new connection'):up()); -- send unavailable - -- self:handle_to_occupant(origin, stanza); -- resend available - --end - else -- enter room - -- a MUC service MUST NOT allow empty or invisible Room Nicknames - -- (i.e., Room Nicknames that consist only of one or more space characters). - if not select(3, jid_split(to)):find("[^ ]") then -- resourceprep turns all whitespace into 0x20 - module:log("debug", "Rejecting invisible nickname"); - origin.send(st.error_reply(stanza, "cancel", "not-allowed")); - return; - end - local new_nick = to; - local is_merge; - if self._occupants[to] then - if jid_bare(from) ~= jid_bare(self._occupants[to].jid) then - new_nick = nil; - end - is_merge = true; - end - local password = stanza:get_child("x", "http://jabber.org/protocol/muc"); - password = password and password:get_child("password", "http://jabber.org/protocol/muc"); - password = password and password[1] ~= "" and password[1]; - if self:get_password() and self:get_password() ~= password then - log("debug", "%s couldn't join due to invalid password: %s", from, to); - local reply = st.error_reply(stanza, "auth", "not-authorized"):up(); - reply.tags[1].attr.code = "401"; - origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); - elseif not new_nick then - log("debug", "%s couldn't join due to nick conflict: %s", from, to); - local reply = st.error_reply(stanza, "cancel", "conflict"):up(); - reply.tags[1].attr.code = "409"; - origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); - else - log("debug", "%s joining as %s", from, to); - if not next(self._affiliations) then -- new room, no owners - self._affiliations[jid_bare(from)] = "owner"; - if self.locked and not stanza:get_child("x", "http://jabber.org/protocol/muc") then - self.locked = nil; -- Older groupchat protocol doesn't lock - end - elseif self.locked then -- Deny entry - module:log("debug", "Room is locked, denying entry"); - origin.send(st.error_reply(stanza, "cancel", "item-not-found")); - return; - end - local affiliation = self:get_affiliation(from); - local role = self:get_default_role(affiliation) - if role then -- new occupant - if not is_merge then - self._occupants[to] = {affiliation=affiliation, role=role, jid=from, sessions={[from]=get_filtered_presence(stanza)}}; - else - self._occupants[to].sessions[from] = get_filtered_presence(stanza); - end - self._jid_nick[from] = to; - self:send_occupant_list(from); - pr.attr.from = to; - pr:tag("x", {xmlns='http://jabber.org/protocol/muc#user'}) - :tag("item", {affiliation=affiliation or "none", role=role or "none"}):up(); - if not is_merge then - self:broadcast_except_nick(pr, to); - end - pr:tag("status", {code='110'}):up(); - if self._data.whois == 'anyone' then - pr:tag("status", {code='100'}):up(); - end - if self.locked then - pr:tag("status", {code='201'}):up(); - end - pr.attr.to = from; - self:_route_stanza(pr); - self:send_history(from, stanza); - self:send_subject(from); - elseif not affiliation then -- registration required for entering members-only room - local reply = st.error_reply(stanza, "auth", "registration-required"):up(); - reply.tags[1].attr.code = "407"; - origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); - else -- banned - local reply = st.error_reply(stanza, "auth", "forbidden"):up(); - reply.tags[1].attr.code = "403"; - origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); - end + local id = stanza.attr.id; + local occupant = self:get_occupant_by_nick(to); + if (type == "error" or type == "result") then + do -- deconstruct_stanza_id + if not occupant then return nil; end + local from_jid, orig_id, to_jid_hash = (base64.decode(id) or ""):match("^(%Z+)%z(%Z*)%z(.+)$"); + if not(from == from_jid or from == jid_bare(from_jid)) then return nil; end + local from_occupant_jid = self:get_occupant_jid(from_jid); + if from_occupant_jid == nil then return nil; end + local session_jid + for to_jid in occupant:each_session() do + if md5(to_jid) == to_jid_hash then + session_jid = to_jid; + break; end end - elseif type ~= 'result' then -- bad type - if type ~= 'visible' and type ~= 'invisible' then -- COMPAT ejabberd can broadcast or forward XEP-0018 presences - origin.send(st.error_reply(stanza, "modify", "bad-request")); -- FIXME correct error? - end + if session_jid == nil then return nil; end + stanza.attr.from, stanza.attr.to, stanza.attr.id = from_occupant_jid, session_jid, orig_id; end - elseif not current_nick then -- not in room - if (type == "error" or type == "result") and stanza.name == "iq" then - local id = stanza.attr.id; - stanza.attr.from, stanza.attr.to, stanza.attr.id = deconstruct_stanza_id(self, stanza); - if stanza.attr.id then - self:_route_stanza(stanza); - end - stanza.attr.from, stanza.attr.to, stanza.attr.id = from, to, id; - elseif type ~= "error" then + log("debug", "%s sent private iq stanza to %s (%s)", from, to, stanza.attr.to); + self:route_stanza(stanza); + stanza.attr.from, stanza.attr.to, stanza.attr.id = from, to, id; + return true; + else -- Type is "get" or "set" + local current_nick = self:get_occupant_jid(from); + if not current_nick then origin.send(st.error_reply(stanza, "cancel", "not-acceptable")); + return true; end - elseif stanza.name == "message" and type == "groupchat" then -- groupchat messages not allowed in PM - origin.send(st.error_reply(stanza, "modify", "bad-request")); - elseif current_nick and stanza.name == "message" and type == "error" and is_kickable_error(stanza) then - log("debug", "%s kicked from %s for sending an error message", current_nick, self.jid); - self:handle_to_occupant(origin, build_unavailable_presence_from_error(stanza)); -- send unavailable - else -- private stanza - local o_data = self._occupants[to]; - if o_data then - log("debug", "%s sent private stanza to %s (%s)", from, to, o_data.jid); - if stanza.name == "iq" then - local id = stanza.attr.id; - if stanza.attr.type == "get" or stanza.attr.type == "set" then - stanza.attr.from, stanza.attr.to, stanza.attr.id = construct_stanza_id(self, stanza); - else - stanza.attr.from, stanza.attr.to, stanza.attr.id = deconstruct_stanza_id(self, stanza); - end - if type == 'get' and stanza.tags[1].attr.xmlns == 'vcard-temp' then - stanza.attr.to = jid_bare(stanza.attr.to); - end - if stanza.attr.id then - self:_route_stanza(stanza); - end - stanza.attr.from, stanza.attr.to, stanza.attr.id = from, to, id; - else -- message - stanza:tag("x", { xmlns = "http://jabber.org/protocol/muc#user" }):up(); - stanza.attr.from = current_nick; - for jid in pairs(o_data.sessions) do - stanza.attr.to = jid; - self:_route_stanza(stanza); - end - stanza.attr.from, stanza.attr.to = from, to; - end - elseif type ~= "error" and type ~= "result" then -- recipient not in room + if not occupant then -- recipient not in room origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Recipient not in room")); + return true; + end + do -- construct_stanza_id + stanza.attr.id = base64.encode(occupant.jid.."\0"..stanza.attr.id.."\0"..md5(from)); + end + stanza.attr.from, stanza.attr.to = current_nick, occupant.jid; + log("debug", "%s sent private iq stanza to %s (%s)", from, to, occupant.jid); + if stanza.tags[1].attr.xmlns == 'vcard-temp' then + stanza.attr.to = jid_bare(stanza.attr.to); end + self:route_stanza(stanza); + stanza.attr.from, stanza.attr.to, stanza.attr.id = from, to, id; + return true; end end +function room_mt:handle_message_to_occupant(origin, stanza) + local from, to = stanza.attr.from, stanza.attr.to; + local current_nick = self:get_occupant_jid(from); + local type = stanza.attr.type; + if not current_nick then -- not in room + if type ~= "error" then + origin.send(st.error_reply(stanza, "cancel", "not-acceptable")); + end + return true; + end + if type == "groupchat" then -- groupchat messages not allowed in PM + origin.send(st.error_reply(stanza, "modify", "bad-request")); + return true; + elseif type == "error" and is_kickable_error(stanza) then + log("debug", "%s kicked from %s for sending an error message", current_nick, self.jid); + return self:handle_kickable(origin, stanza); -- send unavailable + end + + local o_data = self:get_occupant_by_nick(to); + if not o_data then + origin.send(st.error_reply(stanza, "cancel", "item-not-found", "Recipient not in room")); + return true; + end + log("debug", "%s sent private message stanza to %s (%s)", from, to, o_data.jid); + stanza:tag("x", { xmlns = "http://jabber.org/protocol/muc#user" }):up(); + stanza.attr.from = current_nick; + self:route_to_occupant(o_data, stanza) + -- TODO: Remove x tag? + stanza.attr.from = from; + return true; +end + function room_mt:send_form(origin, stanza) origin.send(st.reply(stanza):query("http://jabber.org/protocol/muc#owner") :add_child(self:get_form_layout(stanza.attr.from):form()) @@ -631,361 +763,389 @@ function room_mt:get_form_layout(actor) name = 'FORM_TYPE', type = 'hidden', value = 'http://jabber.org/protocol/muc#roomconfig' - }, - { - name = 'muc#roomconfig_roomname', - type = 'text-single', - label = 'Name', - value = self:get_name() or "", - }, - { - name = 'muc#roomconfig_roomdesc', - type = 'text-single', - label = 'Description', - value = self:get_description() or "", - }, - { - name = 'muc#roomconfig_persistentroom', - type = 'boolean', - label = 'Make Room Persistent?', - value = self:get_persistent() - }, - { - name = 'muc#roomconfig_publicroom', - type = 'boolean', - label = 'Make Room Publicly Searchable?', - value = not self:get_hidden() - }, - { - name = 'muc#roomconfig_changesubject', - type = 'boolean', - label = 'Allow Occupants to Change Subject?', - value = self:get_changesubject() - }, - { - name = 'muc#roomconfig_whois', - type = 'list-single', - label = 'Who May Discover Real JIDs?', - value = { - { value = 'moderators', label = 'Moderators Only', default = self._data.whois == 'moderators' }, - { value = 'anyone', label = 'Anyone', default = self._data.whois == 'anyone' } - } - }, - { - name = 'muc#roomconfig_roomsecret', - type = 'text-private', - label = 'Password', - value = self:get_password() or "", - }, - { - name = 'muc#roomconfig_moderatedroom', - type = 'boolean', - label = 'Make Room Moderated?', - value = self:get_moderated() - }, - { - name = 'muc#roomconfig_membersonly', - type = 'boolean', - label = 'Make Room Members-Only?', - value = self:get_members_only() - }, - { - name = 'muc#roomconfig_historylength', - type = 'text-single', - label = 'Maximum Number of History Messages Returned by Room', - value = tostring(self:get_historylength()) } }); return module:fire_event("muc-config-form", { room = self, actor = actor, form = form }) or form; end function room_mt:process_form(origin, stanza) - local query = stanza.tags[1]; - local form; - for _, tag in ipairs(query.tags) do if tag.name == "x" and tag.attr.xmlns == "jabber:x:data" then form = tag; break; end end - if not form then origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); return; end - if form.attr.type == "cancel" then origin.send(st.reply(stanza)); return; end - if form.attr.type ~= "submit" then origin.send(st.error_reply(stanza, "cancel", "bad-request", "Not a submitted form")); return; end - - if form.tags[1] == nil then - -- instant room - if self.save then self:save(true); end + local form = stanza.tags[1]:get_child("x", "jabber:x:data"); + if form.attr.type == "cancel" then origin.send(st.reply(stanza)); - return true; - end + elseif form.attr.type == "submit" then + local fields, errors, present; + if form.tags[1] == nil then -- Instant room + fields, present = {}, {}; + else + fields, errors, present = self:get_form_layout(stanza.attr.from):data(form); + if fields.FORM_TYPE ~= "http://jabber.org/protocol/muc#roomconfig" then + origin.send(st.error_reply(stanza, "cancel", "bad-request", "Form is not of type room configuration")); + return true; + end + end - local fields, errors, present = self:get_form_layout(stanza.attr.from):data(form); - if fields.FORM_TYPE ~= "http://jabber.org/protocol/muc#roomconfig" then - origin.send(st.error_reply(stanza, "cancel", "bad-request", "Form is not of type room configuration")); - return; - end + local event = {room = self; origin = origin; stanza = stanza; fields = fields; status_codes = {};}; + function event.update_option(name, field, allowed) + local new = fields[field]; + if new == nil then return; end + if allowed and not allowed[new] then return; end + if new == self["get_"..name](self) then return; end + event.status_codes["104"] = true; + self["set_"..name](self, new); + return true; + end + module:fire_event("muc-config-submitted", event); + for submitted_field in pairs(present) do + event.field, event.value = submitted_field, fields[submitted_field]; + module:fire_event("muc-config-submitted/"..submitted_field, event); + end + event.field, event.value = nil, nil; - local changed = {}; + self:save(true); + origin.send(st.reply(stanza)); - local function handle_option(name, field, allowed) - if not present[field] then return; end - local new = fields[field]; - if allowed and not allowed[new] then return; end - if new == self["get_"..name](self) then return; end - changed[name] = true; - self["set_"..name](self, new); + if next(event.status_codes) then + local msg = st.message({type='groupchat', from=self.jid}) + :tag('x', {xmlns='http://jabber.org/protocol/muc#user'}) + for code in pairs(event.status_codes) do + msg:tag("status", {code = code;}):up(); + end + msg:up(); + self:broadcast_message(msg); + end + else + origin.send(st.error_reply(stanza, "cancel", "bad-request", "Not a submitted form")); end + return true; +end - local event = { room = self, fields = fields, changed = changed, stanza = stanza, origin = origin, update_option = handle_option }; - module:fire_event("muc-config-submitted", event); - - handle_option("name", "muc#roomconfig_roomname"); - handle_option("description", "muc#roomconfig_roomdesc"); - handle_option("persistent", "muc#roomconfig_persistentroom"); - handle_option("moderated", "muc#roomconfig_moderatedroom"); - handle_option("members_only", "muc#roomconfig_membersonly"); - handle_option("public", "muc#roomconfig_publicroom"); - handle_option("changesubject", "muc#roomconfig_changesubject"); - handle_option("historylength", "muc#roomconfig_historylength"); - handle_option("whois", "muc#roomconfig_whois", valid_whois); - handle_option("password", "muc#roomconfig_roomsecret"); - - if self.save then self:save(true); end - if self.locked then - module:fire_event("muc-room-unlocked", { room = self }); - self.locked = nil; +-- Removes everyone from the room +function room_mt:clear(x) + x = x or st.stanza("x", {xmlns='http://jabber.org/protocol/muc#user'}); + local occupants_updated = {}; + for nick, occupant in self:each_occupant() do -- luacheck: ignore 213 + occupant.role = nil; + self:save_occupant(occupant); + occupants_updated[occupant] = true; end - origin.send(st.reply(stanza)); - - if next(changed) then - local msg = st.message({type='groupchat', from=self.jid}) - :tag('x', {xmlns='http://jabber.org/protocol/muc#user'}) - :tag('status', {code = '104'}):up(); - if changed.whois then - local code = (self:get_whois() == 'moderators') and "173" or "172"; - msg.tags[1]:tag('status', {code = code}):up(); - end - self:broadcast_message(msg, false) + for occupant in pairs(occupants_updated) do + self:publicise_occupant_status(occupant, x); + module:fire_event("muc-occupant-left", { room = self; nick = occupant.nick; occupant = occupant;}); end end function room_mt:destroy(newjid, reason, password) - local pr = st.presence({type = "unavailable"}) - :tag("x", {xmlns = "http://jabber.org/protocol/muc#user"}) - :tag("item", { affiliation='none', role='none' }):up() - :tag("destroy", {jid=newjid}) - if reason then pr:tag("reason"):text(reason):up(); end - if password then pr:tag("password"):text(password):up(); end - for nick, occupant in pairs(self._occupants) do - pr.attr.from = nick; - for jid in pairs(occupant.sessions) do - pr.attr.to = jid; - self:_route_stanza(pr); - self._jid_nick[jid] = nil; + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}) + :tag("item", { affiliation='none', role='none' }):up() + :tag("destroy", {jid=newjid}); + if reason then x:tag("reason"):text(reason):up(); end + if password then x:tag("password"):text(password):up(); end + x:up(); + self:clear(x); + module:fire_event("muc-room-destroyed", { room = self }); + return true; +end + +function room_mt:handle_disco_info_get_query(origin, stanza) + origin.send(self:get_disco_info(stanza)); + return true; +end + +function room_mt:handle_disco_items_get_query(origin, stanza) + origin.send(self:get_disco_items(stanza)); + return true; +end + +function room_mt:handle_admin_query_set_command(origin, stanza) + local item = stanza.tags[1].tags[1]; + if not item then + origin.send(st.error_reply(stanza, "cancel", "bad-request")); + end + if item.attr.jid then -- Validate provided JID + item.attr.jid = jid_prep(item.attr.jid); + if not item.attr.jid then + origin.send(st.error_reply(stanza, "modify", "jid-malformed")); + return true; end - self._occupants[nick] = nil; end - self:set_persistent(false); - module:fire_event("muc-room-destroyed", { room = self }); + if not item.attr.jid and item.attr.nick then -- COMPAT Workaround for Miranda sending 'nick' instead of 'jid' when changing affiliation + local occupant = self:get_occupant_by_nick(self.jid.."/"..item.attr.nick); + if occupant then item.attr.jid = occupant.jid; end + elseif not item.attr.nick and item.attr.jid then + local nick = self:get_occupant_jid(item.attr.jid); + if nick then item.attr.nick = jid_resource(nick); end + end + local actor = stanza.attr.from; + local reason = item:get_child_text("reason"); + local success, errtype, err + if item.attr.affiliation and item.attr.jid and not item.attr.role then + success, errtype, err = self:set_affiliation(actor, item.attr.jid, item.attr.affiliation, reason); + elseif item.attr.role and item.attr.nick and not item.attr.affiliation then + success, errtype, err = self:set_role(actor, self.jid.."/"..item.attr.nick, item.attr.role, reason); + else + success, errtype, err = nil, "cancel", "bad-request"; + end + self:save(true); + if not success then + origin.send(st.error_reply(stanza, errtype, err)); + else + origin.send(st.reply(stanza)); + end return true; end -function room_mt:handle_to_room(origin, stanza) -- presence changes and groupchat messages, along with disco/etc - local type = stanza.attr.type; - local xmlns = stanza.tags[1] and stanza.tags[1].attr.xmlns; - if stanza.name == "iq" then - if xmlns == "http://jabber.org/protocol/disco#info" and type == "get" and not stanza.tags[1].attr.node then - origin.send(self:get_disco_info(stanza)); - elseif xmlns == "http://jabber.org/protocol/disco#items" and type == "get" and not stanza.tags[1].attr.node then - origin.send(self:get_disco_items(stanza)); - elseif xmlns == "http://jabber.org/protocol/muc#admin" then - local actor = stanza.attr.from; - local affiliation = self:get_affiliation(actor); - local current_nick = self._jid_nick[actor]; - local role = current_nick and self._occupants[current_nick].role or self:get_default_role(affiliation); - local item = stanza.tags[1].tags[1]; - if item and item.name == "item" then - if type == "set" then - local callback = function() origin.send(st.reply(stanza)); end - if item.attr.jid then -- Validate provided JID - item.attr.jid = jid_prep(item.attr.jid); - if not item.attr.jid then - origin.send(st.error_reply(stanza, "modify", "jid-malformed")); - return; - end - end - if not item.attr.jid and item.attr.nick then -- COMPAT Workaround for Miranda sending 'nick' instead of 'jid' when changing affiliation - local occupant = self._occupants[self.jid.."/"..item.attr.nick]; - if occupant then item.attr.jid = occupant.jid; end - elseif not item.attr.nick and item.attr.jid then - local nick = self._jid_nick[item.attr.jid]; - if nick then item.attr.nick = select(3, jid_split(nick)); end - end - local reason = item.tags[1] and item.tags[1].name == "reason" and #item.tags[1] == 1 and item.tags[1][1]; - if item.attr.affiliation and item.attr.jid and not item.attr.role then - local success, errtype, err = self:set_affiliation(actor, item.attr.jid, item.attr.affiliation, callback, reason); - if not success then origin.send(st.error_reply(stanza, errtype, err)); end - elseif item.attr.role and item.attr.nick and not item.attr.affiliation then - local success, errtype, err = self:set_role(actor, self.jid.."/"..item.attr.nick, item.attr.role, callback, reason); - if not success then origin.send(st.error_reply(stanza, errtype, err)); end - else - origin.send(st.error_reply(stanza, "cancel", "bad-request")); - end - elseif type == "get" then - local _aff = item.attr.affiliation; - local _rol = item.attr.role; - if _aff and not _rol then - if affiliation == "owner" or (affiliation == "admin" and _aff ~= "owner" and _aff ~= "admin") - or (affiliation and affiliation ~= "outcast" and self:get_members_only() and self:get_whois() == "anyone") then - local reply = st.reply(stanza):query("http://jabber.org/protocol/muc#admin"); - for jid, affiliation in pairs(self._affiliations) do - if affiliation == _aff then - reply:tag("item", {affiliation = _aff, jid = jid}):up(); - end - end - origin.send(reply); - else - origin.send(st.error_reply(stanza, "auth", "forbidden")); - end - elseif _rol and not _aff then - if role == "moderator" then - -- TODO allow admins and owners not in room? Provide read-only access to everyone who can see the participants anyway? - if _rol == "none" then _rol = nil; end - local reply = st.reply(stanza):query("http://jabber.org/protocol/muc#admin"); - for occupant_jid, occupant in pairs(self._occupants) do - if occupant.role == _rol then - reply:tag("item", { - nick = select(3, jid_split(occupant_jid)), - role = _rol or "none", - affiliation = occupant.affiliation or "none", - jid = occupant.jid - }):up(); - end - end - origin.send(reply); - else - origin.send(st.error_reply(stanza, "auth", "forbidden")); - end - else - origin.send(st.error_reply(stanza, "cancel", "bad-request")); - end - end - elseif type == "set" or type == "get" then - origin.send(st.error_reply(stanza, "cancel", "bad-request")); - end - elseif xmlns == "http://jabber.org/protocol/muc#owner" and (type == "get" or type == "set") and stanza.tags[1].name == "query" then - if self:get_affiliation(stanza.attr.from) ~= "owner" then - origin.send(st.error_reply(stanza, "auth", "forbidden", "Only owners can configure rooms")); - elseif stanza.attr.type == "get" then - self:send_form(origin, stanza); - elseif stanza.attr.type == "set" then - local child = stanza.tags[1].tags[1]; - if not child then - origin.send(st.error_reply(stanza, "modify", "bad-request")); - elseif child.name == "destroy" then - local newjid = child.attr.jid; - local reason, password; - for _,tag in ipairs(child.tags) do - if tag.name == "reason" then - reason = #tag.tags == 0 and tag[1]; - elseif tag.name == "password" then - password = #tag.tags == 0 and tag[1]; - end - end - self:destroy(newjid, reason, password); - origin.send(st.reply(stanza)); - else - self:process_form(origin, stanza); - end +function room_mt:handle_admin_query_get_command(origin, stanza) + local actor = stanza.attr.from; + local affiliation = self:get_affiliation(actor); + local item = stanza.tags[1].tags[1]; + local _aff = item.attr.affiliation; + local _aff_rank = valid_affiliations[_aff or "none"]; + local _rol = item.attr.role; + if _aff and _aff_rank and not _rol then + -- You need to be at least an admin, and be requesting info about your affifiliation or lower + -- e.g. an admin can't ask for a list of owners + local affiliation_rank = valid_affiliations[affiliation or "none"]; + if affiliation_rank >= valid_affiliations.admin and affiliation_rank >= _aff_rank + or self:get_members_only() and self:get_whois() == "anyone" and affiliation_rank >= valid_affiliations.member then + local reply = st.reply(stanza):query("http://jabber.org/protocol/muc#admin"); + for jid in self:each_affiliation(_aff or "none") do + reply:tag("item", {affiliation = _aff, jid = jid}):up(); end - elseif type == "set" or type == "get" then - origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); - end - elseif stanza.name == "message" and type == "groupchat" then - local from = stanza.attr.from; - local current_nick = self._jid_nick[from]; - local occupant = self._occupants[current_nick]; - if not occupant then -- not in room - origin.send(st.error_reply(stanza, "cancel", "not-acceptable")); - elseif occupant.role == "visitor" then - origin.send(st.error_reply(stanza, "auth", "forbidden")); + origin.send(reply:up()); + return true; else - local from = stanza.attr.from; - stanza.attr.from = current_nick; - local subject = stanza:get_child_text("subject"); - if subject then - if occupant.role == "moderator" or - ( self._data.changesubject and occupant.role == "participant" ) then -- and participant - self:set_subject(current_nick, subject); - else - stanza.attr.from = from; - origin.send(st.error_reply(stanza, "auth", "forbidden")); - end - else - self:broadcast_message(stanza, self:get_historylength() > 0 and stanza:get_child("body")); - end - stanza.attr.from = from; - end - elseif stanza.name == "message" and type == "error" and is_kickable_error(stanza) then - local current_nick = self._jid_nick[stanza.attr.from]; - log("debug", "%s kicked from %s for sending an error message", current_nick, self.jid); - self:handle_to_occupant(origin, build_unavailable_presence_from_error(stanza)); -- send unavailable - elseif stanza.name == "presence" then -- hack - some buggy clients send presence updates to the room rather than their nick - local to = stanza.attr.to; - local current_nick = self._jid_nick[stanza.attr.from]; - if current_nick then - stanza.attr.to = current_nick; - self:handle_to_occupant(origin, stanza); - stanza.attr.to = to; - elseif type ~= "error" and type ~= "result" then - origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); + origin.send(st.error_reply(stanza, "auth", "forbidden")); + return true; end - elseif stanza.name == "message" and not(type == "chat" or type == "error" or type == "groupchat" or type == "headline") and #stanza.tags == 1 - and self._jid_nick[stanza.attr.from] and stanza.tags[1].name == "x" and stanza.tags[1].attr.xmlns == "http://jabber.org/protocol/muc#user" then - local x = stanza.tags[1]; - local payload = (#x.tags == 1 and x.tags[1]); - if payload and payload.name == "invite" and payload.attr.to then - local _from, _to = stanza.attr.from, stanza.attr.to; - local _invitee = jid_prep(payload.attr.to); - if _invitee then - local _reason = payload.tags[1] and payload.tags[1].name == 'reason' and #payload.tags[1].tags == 0 and payload.tags[1][1]; - local invite = st.message({from = _to, to = _invitee, id = stanza.attr.id}) - :tag('x', {xmlns='http://jabber.org/protocol/muc#user'}) - :tag('invite', {from=_from}) - :tag('reason'):text(_reason or ""):up() - :up(); - if self:get_password() then - invite:tag("password"):text(self:get_password()):up(); - end - invite:up() - :tag('x', {xmlns="jabber:x:conference", jid=_to}) -- COMPAT: Some older clients expect this - :text(_reason or "") - :up() - :tag('body') -- Add a plain message for clients which don't support invites - :text(_from..' invited you to the room '.._to..(_reason and (' ('.._reason..')') or "")) - :up(); - if self:get_members_only() and not self:get_affiliation(_invitee) then - log("debug", "%s invited %s into members only room %s, granting membership", _from, _invitee, _to); - self:set_affiliation(_from, _invitee, "member", nil, "Invited by " .. self._jid_nick[_from]) + elseif _rol and valid_roles[_rol or "none"] and not _aff then + local role = self:get_role(self:get_occupant_jid(actor)) or self:get_default_role(affiliation); + if valid_roles[role or "none"] >= valid_roles.moderator then + if _rol == "none" then _rol = nil; end + local reply = st.reply(stanza):query("http://jabber.org/protocol/muc#admin"); + -- TODO: whois check here? (though fully anonymous rooms are not supported) + for occupant_jid, occupant in self:each_occupant() do + if occupant.role == _rol then + local nick = jid_resource(occupant_jid); + self:build_item_list(occupant, reply, false, nick); end - self:_route_stanza(invite); - else - origin.send(st.error_reply(stanza, "cancel", "jid-malformed")); end + origin.send(reply:up()); + return true; else - origin.send(st.error_reply(stanza, "cancel", "bad-request")); + origin.send(st.error_reply(stanza, "auth", "forbidden")); + return true; end else - if type == "error" or type == "result" then return; end - origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); + origin.send(st.error_reply(stanza, "cancel", "bad-request")); + return true; + end +end + +function room_mt:handle_owner_query_get_to_room(origin, stanza) + if self:get_affiliation(stanza.attr.from) ~= "owner" then + origin.send(st.error_reply(stanza, "auth", "forbidden", "Only owners can configure rooms")); + return true; end + + self:send_form(origin, stanza); + return true; end +function room_mt:handle_owner_query_set_to_room(origin, stanza) + if self:get_affiliation(stanza.attr.from) ~= "owner" then + origin.send(st.error_reply(stanza, "auth", "forbidden", "Only owners can configure rooms")); + return true; + end -function room_mt:handle_stanza(origin, stanza) - local to_node, to_host, to_resource = jid_split(stanza.attr.to); - if to_resource then - self:handle_to_occupant(origin, stanza); + local child = stanza.tags[1].tags[1]; + if not child then + origin.send(st.error_reply(stanza, "modify", "bad-request")); + return true; + elseif child.name == "destroy" then + local newjid = child.attr.jid; + local reason = child:get_child_text("reason"); + local password = child:get_child_text("password"); + self:destroy(newjid, reason, password); + origin.send(st.reply(stanza)); + return true; + elseif child.name == "x" and child.attr.xmlns == "jabber:x:data" then + return self:process_form(origin, stanza); else - self:handle_to_room(origin, stanza); + origin.send(st.error_reply(stanza, "cancel", "service-unavailable")); + return true; + end +end + +function room_mt:handle_groupchat_to_room(origin, stanza) + local from = stanza.attr.from; + local occupant = self:get_occupant_by_real_jid(from); + if module:fire_event("muc-occupant-groupchat", { + room = self; origin = origin; stanza = stanza; from = from; occupant = occupant; + }) then return true; end + stanza.attr.from = occupant.nick; + self:broadcast_message(stanza); + stanza.attr.from = from; + return true; +end + +-- Role check +module:hook("muc-occupant-groupchat", function(event) + local role_rank = valid_roles[event.occupant and event.occupant.role or "none"]; + if role_rank <= valid_roles.none then + event.origin.send(st.error_reply(event.stanza, "cancel", "not-acceptable")); + return true; + elseif role_rank <= valid_roles.visitor then + event.origin.send(st.error_reply(event.stanza, "auth", "forbidden")); + return true; + end +end, 50); + +-- hack - some buggy clients send presence updates to the room rather than their nick +function room_mt:handle_presence_to_room(origin, stanza) + local current_nick = self:get_occupant_jid(stanza.attr.from); + local handled + if current_nick then + local to = stanza.attr.to; + stanza.attr.to = current_nick; + handled = self:handle_presence_to_occupant(origin, stanza); + stanza.attr.to = to; + end + return handled; +end + +-- Need visitor role or higher to invite +module:hook("muc-pre-invite", function(event) + local room, stanza = event.room, event.stanza; + local _from = stanza.attr.from; + local inviter = room:get_occupant_by_real_jid(_from); + local role = inviter and inviter.role or room:get_default_role(room:get_affiliation(_from)); + if valid_roles[role or "none"] <= valid_roles.visitor then + event.origin.send(st.error_reply(stanza, "auth", "forbidden")); + return true; + end +end); + +function room_mt:handle_mediated_invite(origin, stanza) + local payload = stanza:get_child("x", "http://jabber.org/protocol/muc#user"):get_child("invite"); + local invitee = jid_prep(payload.attr.to); + if not invitee then + origin.send(st.error_reply(stanza, "cancel", "jid-malformed")); + return true; + elseif module:fire_event("muc-pre-invite", {room = self, origin = origin, stanza = stanza}) then + return true; + end + local invite = muc_util.filter_muc_x(st.clone(stanza)); + invite.attr.from = self.jid; + invite.attr.to = invitee; + invite:tag('x', {xmlns='http://jabber.org/protocol/muc#user'}) + :tag('invite', {from = stanza.attr.from;}) + :tag('reason'):text(payload:get_child_text("reason")):up() + :up() + :up(); + if not module:fire_event("muc-invite", {room = self, stanza = invite, origin = origin, incoming = stanza}) then + self:route_stanza(invite); + end + return true; +end + +-- COMPAT: Some older clients expect this +module:hook("muc-invite", function(event) + local room, stanza = event.room, event.stanza; + local invite = stanza:get_child("x", "http://jabber.org/protocol/muc#user"):get_child("invite"); + local reason = invite:get_child_text("reason"); + stanza:tag('x', {xmlns = "jabber:x:conference"; jid = room.jid;}) + :text(reason or "") + :up(); +end); + +-- Add a plain message for clients which don't support invites +module:hook("muc-invite", function(event) + local room, stanza = event.room, event.stanza; + if not stanza:get_child("body") then + local invite = stanza:get_child("x", "http://jabber.org/protocol/muc#user"):get_child("invite"); + local reason = invite:get_child_text("reason") or ""; + stanza:tag("body") + :text(invite.attr.from.." invited you to the room "..room.jid..(reason == "" and (" ("..reason..")") or "")) + :up(); + end +end); + +function room_mt:handle_mediated_decline(origin, stanza) + local payload = stanza:get_child("x", "http://jabber.org/protocol/muc#user"):get_child("decline"); + local declinee = jid_prep(payload.attr.to); + if not declinee then + origin.send(st.error_reply(stanza, "cancel", "jid-malformed")); + return true; + elseif module:fire_event("muc-pre-decline", {room = self, origin = origin, stanza = stanza}) then + return true; + end + local decline = muc_util.filter_muc_x(st.clone(stanza)); + decline.attr.from = self.jid; + decline.attr.to = declinee; + decline:tag("x", {xmlns = "http://jabber.org/protocol/muc#user"}) + :tag("decline", {from = stanza.attr.from}) + :tag("reason"):text(payload:get_child_text("reason")):up() + :up() + :up(); + if not module:fire_event("muc-decline", {room = self, stanza = decline, origin = origin, incoming = stanza}) then + declinee = decline.attr.to; -- re-fetch, in case event modified it + local occupant + if jid_bare(declinee) == self.jid then -- declinee jid is already an in-room jid + occupant = self:get_occupant_by_nick(declinee); + end + if occupant then + self:route_to_occupant(occupant, decline); + else + self:route_stanza(decline); + end end + return true; end -function room_mt:route_stanza(stanza) end -- Replace with a routing function, e.g., function(room, stanza) core_route_stanza(origin, stanza); end +-- Add a plain message for clients which don't support declines +module:hook("muc-decline", function(event) + local room, stanza = event.room, event.stanza; + if not stanza:get_child("body") then + local decline = stanza:get_child("x", "http://jabber.org/protocol/muc#user"):get_child("decline"); + local reason = decline:get_child_text("reason") or ""; + stanza:tag("body") + :text(decline.attr.from.." declined your invite to the room "..room.jid..(reason == "" and (" ("..reason..")") or "")) + :up(); + end +end); + +function room_mt:handle_message_to_room(origin, stanza) + local type = stanza.attr.type; + if type == "groupchat" then + return self:handle_groupchat_to_room(origin, stanza) + elseif type == "error" and is_kickable_error(stanza) then + return self:handle_kickable(origin, stanza) + elseif type == nil or type == "normal" then + local x = stanza:get_child("x", "http://jabber.org/protocol/muc#user"); + if x then + local payload = x.tags[1]; + if payload == nil then --luacheck: ignore 542 + -- fallthrough + elseif payload.name == "invite" and payload.attr.to then + return self:handle_mediated_invite(origin, stanza) + elseif payload.name == "decline" and payload.attr.to then + return self:handle_mediated_decline(origin, stanza) + end + origin.send(st.error_reply(stanza, "cancel", "bad-request")); + return true; + end + + local form = stanza:get_child("x", "jabber:x:data"); + local form_type = dataform.get_type(form); + if form_type == "http://jabber.org/protocol/muc#request" then + self:handle_role_request(origin, stanza, form); + return true; + end + end +end + +function room_mt:route_stanza(stanza) -- luacheck: ignore 212 + module:send(stanza); +end function room_mt:get_affiliation(jid) local node, host, resource = jid_split(jid); @@ -994,184 +1154,184 @@ function room_mt:get_affiliation(jid) if not result and self._affiliations[host] == "outcast" then result = "outcast"; end -- host banned return result; end -function room_mt:set_affiliation(actor, jid, affiliation, callback, reason) - jid = jid_bare(jid); - if affiliation == "none" then affiliation = nil; end - if affiliation and affiliation ~= "outcast" and affiliation ~= "owner" and affiliation ~= "admin" and affiliation ~= "member" then + +-- Iterates over jid, affiliation pairs +function room_mt:each_affiliation(with_affiliation) + if not with_affiliation then + return pairs(self._affiliations); + else + return function(_affiliations, jid) + local affiliation; + repeat -- Iterate until we get a match + jid, affiliation = next(_affiliations, jid); + until jid == nil or affiliation == with_affiliation + return jid, affiliation; + end, self._affiliations, nil + end +end + +function room_mt:set_affiliation(actor, jid, affiliation, reason) + if not actor then return nil, "modify", "not-acceptable"; end; + + local node, host, resource = jid_split(jid); + if not host then return nil, "modify", "not-acceptable"; end + jid = jid_join(node, host); -- Bare + local is_host_only = node == nil; + + if valid_affiliations[affiliation or "none"] == nil then return nil, "modify", "not-acceptable"; end - if actor ~= true then + affiliation = affiliation ~= "none" and affiliation or nil; -- coerces `affiliation == false` to `nil` + + local target_affiliation = self._affiliations[jid]; -- Raw; don't want to check against host + local is_downgrade = valid_affiliations[target_affiliation or "none"] > valid_affiliations[affiliation or "none"]; + + if actor == true then + actor = nil -- So we can pass it safely to 'publicise_occupant_status' below + else local actor_affiliation = self:get_affiliation(actor); - local target_affiliation = self:get_affiliation(jid); - if target_affiliation == affiliation then -- no change, shortcut - if callback then callback(); end - return true; - end - if actor_affiliation ~= "owner" then - if affiliation == "owner" or affiliation == "admin" or actor_affiliation ~= "admin" or target_affiliation == "owner" or target_affiliation == "admin" then - return nil, "cancel", "not-allowed"; - end - elseif target_affiliation == "owner" and jid_bare(actor) == jid then -- self change - local is_last = true; - for j, aff in pairs(self._affiliations) do if j ~= jid and aff == "owner" then is_last = false; break; end end - if is_last then - return nil, "cancel", "conflict"; + if actor_affiliation == "owner" then + if jid_bare(actor) == jid then -- self change + -- need at least one owner + local is_last = true; + for j in self:each_affiliation("owner") do + if j ~= jid then is_last = false; break; end + end + if is_last then + return nil, "cancel", "conflict"; + end end + -- owners can do anything else + elseif affiliation == "owner" or affiliation == "admin" + or actor_affiliation ~= "admin" + or target_affiliation == "owner" or target_affiliation == "admin" then + -- Can't demote owners or other admins + return nil, "cancel", "not-allowed"; end end + + -- Set in 'database' self._affiliations[jid] = affiliation; + + -- Update roles local role = self:get_default_role(affiliation); - local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}) - :tag("item", {affiliation=affiliation or "none", role=role or "none"}) - :tag("reason"):text(reason or ""):up() - :up(); - local presence_type = nil; + local role_rank = valid_roles[role or "none"]; + local occupants_updated = {}; -- Filled with old roles + for nick, occupant in self:each_occupant() do -- luacheck: ignore 213 + if occupant.bare_jid == jid or ( + -- Outcast can be by host. + is_host_only and affiliation == "outcast" and select(2, jid_split(occupant.bare_jid)) == host + ) then + -- need to publcize in all cases; as affiliation in <item/> has changed. + occupants_updated[occupant] = occupant.role; + if occupant.role ~= role and ( + is_downgrade or + valid_roles[occupant.role or "none"] < role_rank -- upgrade + ) then + occupant.role = role; + self:save_occupant(occupant); + end + end + end + + -- Tell the room of the new occupant affiliations+roles + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}); if not role then -- getting kicked - presence_type = "unavailable"; if affiliation == "outcast" then x:tag("status", {code="301"}):up(); -- banned else x:tag("status", {code="321"}):up(); -- affiliation change end end - -- Your own presence should have status 110 - local self_x = st.clone(x); - self_x:tag("status", {code="110"}); - local modified_nicks = {}; - for nick, occupant in pairs(self._occupants) do - if jid_bare(occupant.jid) == jid then - if not role then -- getting kicked - self._occupants[nick] = nil; - else - occupant.affiliation, occupant.role = affiliation, role; - end - for jid,pres in pairs(occupant.sessions) do -- remove for all sessions of the nick - if not role then self._jid_nick[jid] = nil; end - local p = st.clone(pres); - p.attr.from = nick; - p.attr.type = presence_type; - p.attr.to = jid; - if occupant.jid == jid then - -- Broadcast this presence to everyone else later, with the public <x> variant - local bp = st.clone(p); - bp:add_child(x); - modified_nicks[nick] = bp; - end - p:add_child(self_x); - self:_route_stanza(p); + local is_semi_anonymous = self:get_whois() == "moderators"; + for occupant, old_role in pairs(occupants_updated) do + self:publicise_occupant_status(occupant, x, nil, actor, reason); + if occupant.role == nil then + module:fire_event("muc-occupant-left", {room = self; nick = occupant.nick; occupant = occupant;}); + elseif is_semi_anonymous and + (old_role == "moderator" and occupant.role ~= "moderator") or + (old_role ~= "moderator" and occupant.role == "moderator") then -- Has gained or lost moderator status + -- Send everyone else's presences (as jid visibility has changed) + for real_jid in occupant:each_session() do + self:send_occupant_list(real_jid, function(occupant_jid, occupant) --luacheck: ignore 212 433 + return occupant.bare_jid ~= jid; + end); end end end - if self.save then self:save(); end - if callback then callback(); end - for nick,p in pairs(modified_nicks) do - p.attr.from = nick; - self:broadcast_except_nick(p, nick); - end + + self:save(true); + + module:fire_event("muc-set-affiliation", { + room = self; + actor = actor; + jid = jid; + affiliation = affiliation or "none"; + reason = reason; + previous_affiliation = target_affiliation; + in_room = next(occupants_updated) ~= nil; + }); + return true; end function room_mt:get_role(nick) - local session = self._occupants[nick]; - return session and session.role or nil; + local occupant = self:get_occupant_by_nick(nick); + return occupant and occupant.role or nil; end -function room_mt:can_set_role(actor_jid, occupant_jid, role) - local occupant = self._occupants[occupant_jid]; - if not occupant or not actor_jid then return nil, "modify", "not-acceptable"; end - if actor_jid == true then return true; end +function room_mt:set_role(actor, occupant_jid, role, reason) + if not actor then return nil, "modify", "not-acceptable"; end - local actor = self._occupants[self._jid_nick[actor_jid]]; - if actor and actor.role == "moderator" then - if occupant.affiliation ~= "owner" and occupant.affiliation ~= "admin" then - if actor.affiliation == "owner" or actor.affiliation == "admin" then - return true; - elseif occupant.role ~= "moderator" and role ~= "moderator" then - return true; - end - end + local occupant = self:get_occupant_by_nick(occupant_jid); + if not occupant then return nil, "modify", "item-not-found"; end + + if valid_roles[role or "none"] == nil then + return nil, "modify", "not-acceptable"; end - return nil, "cancel", "not-allowed"; -end -function room_mt:set_role(actor, occupant_jid, role, callback, reason) - if role == "none" then role = nil; end - if role and role ~= "moderator" and role ~= "participant" and role ~= "visitor" then return nil, "modify", "not-acceptable"; end - local allowed, err_type, err_condition = self:can_set_role(actor, occupant_jid, role); - if not allowed then return allowed, err_type, err_condition; end - local occupant = self._occupants[occupant_jid]; - local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}) - :tag("item", {affiliation=occupant.affiliation or "none", nick=select(3, jid_split(occupant_jid)), role=role or "none"}) - :tag("reason"):text(reason or ""):up() - :up(); - local presence_type = nil; - if not role then -- kick - presence_type = "unavailable"; - self._occupants[occupant_jid] = nil; - for jid in pairs(occupant.sessions) do -- remove for all sessions of the nick - self._jid_nick[jid] = nil; - end - x:tag("status", {code = "307"}):up(); + role = role ~= "none" and role or nil; -- coerces `role == false` to `nil` + + if actor == true then + actor = nil -- So we can pass it safely to 'publicise_occupant_status' below else - occupant.role = role; - end - local self_x = st.clone(x); - self_x:tag("status", {code = "110"}):up(); - local bp; - for jid,pres in pairs(occupant.sessions) do -- send to all sessions of the nick - local p = st.clone(pres); - p.attr.from = occupant_jid; - p.attr.type = presence_type; - p.attr.to = jid; - if occupant.jid == jid then - bp = st.clone(p); - bp:add_child(x); + -- Can't do anything to other owners or admins + local occupant_affiliation = self:get_affiliation(occupant.bare_jid); + if occupant_affiliation == "owner" or occupant_affiliation == "admin" then + return nil, "cancel", "not-allowed"; end - p:add_child(self_x); - self:_route_stanza(p); - end - if callback then callback(); end - if bp then - self:broadcast_except_nick(bp, occupant_jid); - end - return true; -end -function room_mt:_route_stanza(stanza) - local muc_child; - local to_occupant = self._occupants[self._jid_nick[stanza.attr.to]]; - local from_occupant = self._occupants[stanza.attr.from]; - if stanza.name == "presence" then - if to_occupant and from_occupant then - if self._data.whois == 'anyone' then - muc_child = stanza:get_child("x", "http://jabber.org/protocol/muc#user"); - else - if to_occupant.role == "moderator" or jid_bare(to_occupant.jid) == jid_bare(from_occupant.jid) then - muc_child = stanza:get_child("x", "http://jabber.org/protocol/muc#user"); - end + -- If you are trying to give or take moderator role you need to be an owner or admin + if occupant.role == "moderator" or role == "moderator" then + local actor_affiliation = self:get_affiliation(actor); + if actor_affiliation ~= "owner" and actor_affiliation ~= "admin" then + return nil, "cancel", "not-allowed"; end end - end - if muc_child then - for _, item in pairs(muc_child.tags) do - if item.name == "item" then - if from_occupant == to_occupant then - item.attr.jid = stanza.attr.to; - else - item.attr.jid = from_occupant.jid; - end - end + + -- Need to be in the room and a moderator + local actor_occupant = self:get_occupant_by_real_jid(actor); + if not actor_occupant or actor_occupant.role ~= "moderator" then + return nil, "cancel", "not-allowed"; end end - self:route_stanza(stanza); - if muc_child then - for _, item in pairs(muc_child.tags) do - if item.name == "item" then - item.attr.jid = nil; - end - end + + local x = st.stanza("x", {xmlns = "http://jabber.org/protocol/muc#user"}); + if not role then + x:tag("status", {code = "307"}):up(); end + occupant.role = role; + self:save_occupant(occupant); + self:publicise_occupant_status(occupant, x, nil, actor, reason); + if role == nil then + module:fire_event("muc-occupant-left", {room = self; nick = occupant.nick; occupant = occupant;}); + end + return true; end +local whois = module:require "muc/whois"; +room_mt.get_whois = whois.get; +room_mt.set_whois = whois.set; + local _M = {}; -- module "muc" function _M.new_room(jid, config) @@ -1179,17 +1339,103 @@ function _M.new_room(jid, config) jid = jid; _jid_nick = {}; _occupants = {}; - _data = { - whois = 'moderators'; - history_length = math.min((config and config.history_length) - or default_history_length, max_history_length); - }; + _data = config or {}; _affiliations = {}; }, room_mt); end -function _M.set_max_history_length(_max_history_length) - max_history_length = _max_history_length or math.huge; +local new_format = module:get_option_boolean("new_muc_storage_format", false); + +function room_mt:freeze(live) + local frozen, state; + if new_format then + frozen = { + _jid = self.jid; + _data = self._data; + }; + for user, affiliation in pairs(self._affiliations) do + frozen[user] = affiliation; + end + else + frozen = { + jid = self.jid; + _data = self._data; + _affiliations = self._affiliations; + }; + end + if live then + state = {}; + for nick, occupant in self:each_occupant() do + state[nick] = { + bare_jid = occupant.bare_jid; + role = occupant.role; + jid = occupant.jid; + } + for jid, presence in occupant:each_session() do + state[jid] = st.preserialize(presence); + end + end + local history = self._history; + if history and history[1] ~= nil then + state._last_message = st.preserialize(history[#history].stanza); + state._last_message_at = history[#history].timestamp; + end + end + return frozen, state; +end + +function _M.restore_room(frozen, state) + local room_jid = frozen._jid or frozen.jid; + local room = _M.new_room(room_jid, frozen._data); + + if state and state._last_message and state._last_message_at then + room._history = { + { stanza = st.deserialize(state._last_message), + timestamp = state._last_message_at, }, + }; + end + + local occupants = {}; + local room_name, room_host = jid_split(room_jid); + + if frozen.jid and frozen._affiliations then + room._affiliations = frozen._affiliations; + else + for jid, data in pairs(frozen) do + local node, host, resource = jid_split(jid); + if host:sub(1,1) ~= "_" and not resource and type(data) == "string" then + -- bare jid: affiliation + room._affiliations[jid] = data; + end + end + end + for jid, data in pairs(state or frozen) do + local node, host, resource = jid_split(jid); + if node or host:sub(1,1) ~= "_" then + if host == room_host and node == room_name and resource and type(data) == "table" then + -- full room jid: bare real jid and role + local nick = jid; + local occupant = occupants[nick] or occupant_lib.new(data.bare_jid, nick); + occupant.bare_jid = data.bare_jid; + occupant.role = data.role; + occupant.jid = data.jid; -- Primary session JID + occupants[nick] = occupant; + elseif type(data) == "table" and data.name == "presence" then + -- full user jid: presence + local nick = data.attr.from; + local occupant = occupants[nick] or occupant_lib.new(nil, nick); + local presence = st.deserialize(data); + occupant:set_session(jid, presence); + occupants[nick] = occupant; + end + end + end + + for _, occupant in pairs(occupants) do + room:save_occupant(occupant); + end + + return room; end _M.room_mt = room_mt; diff --git a/plugins/muc/name.lib.lua b/plugins/muc/name.lib.lua new file mode 100644 index 00000000..2dcb979a --- /dev/null +++ b/plugins/muc/name.lib.lua @@ -0,0 +1,45 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local jid_split = require "util.jid".split; + +local function get_name(room) + return room._data.name or jid_split(room.jid); +end + +local function set_name(room, name) + if name == "" or name == (jid_split(room.jid)) then name = nil; end + if room._data.name == name then return false; end + room._data.name = name; + return true; +end + +module:hook("muc-disco#info", function(event) + event.reply:tag("identity", {category="conference", type="text", name=get_name(event.room)}):up(); +end); + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_roomname"; + type = "text-single"; + label = "Name"; + value = get_name(event.room) or ""; + }); +end, 100-1); + +module:hook("muc-config-submitted/muc#roomconfig_roomname", function(event) + if set_name(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +return { + get = get_name; + set = set_name; +}; diff --git a/plugins/muc/occupant.lib.lua b/plugins/muc/occupant.lib.lua new file mode 100644 index 00000000..8fe4bbdf --- /dev/null +++ b/plugins/muc/occupant.lib.lua @@ -0,0 +1,85 @@ +local pairs = pairs; +local setmetatable = setmetatable; +local st = require "util.stanza"; +local util = module:require "muc/util"; + +local function get_filtered_presence(stanza) + return util.filter_muc_x(st.clone(stanza)); +end + +local occupant_mt = {}; +occupant_mt.__index = occupant_mt; + +local function new_occupant(bare_real_jid, nick) + return setmetatable({ + bare_jid = bare_real_jid; + nick = nick; -- in-room jid + sessions = {}; -- hash from real_jid to presence stanzas. stanzas should not be modified + role = nil; + jid = nil; -- Primary session + }, occupant_mt); +end + +-- Deep copy an occupant +local function copy_occupant(occupant) + local sessions = {}; + for full_jid, presence_stanza in pairs(occupant.sessions) do + -- Don't keep unavailable presences, as they'll accumulate; unless they're the primary session + if presence_stanza.attr.type ~= "unavailable" or full_jid == occupant.jid then + sessions[full_jid] = presence_stanza; + end + end + return setmetatable({ + bare_jid = occupant.bare_jid; + nick = occupant.nick; + sessions = sessions; + role = occupant.role; + jid = occupant.jid; + }, occupant_mt); +end + +-- finds another session to be the primary (there might not be one) +function occupant_mt:choose_new_primary() + for jid, pr in self:each_session() do + if pr.attr.type == nil then + return jid; + end + end + return nil; +end + +function occupant_mt:set_session(real_jid, presence_stanza, replace_primary) + local pr = get_filtered_presence(presence_stanza); + pr.attr.from = self.nick; + pr.attr.to = real_jid; + + self.sessions[real_jid] = pr; + if replace_primary then + self.jid = real_jid; + elseif self.jid == nil or (pr.attr.type == "unavailable" and self.jid == real_jid) then + -- Only leave an unavailable presence as primary when there are no other options + self.jid = self:choose_new_primary() or real_jid; + end +end + +function occupant_mt:remove_session(real_jid) + -- Delete original session + self.sessions[real_jid] = nil; + if self.jid == real_jid then + self.jid = self:choose_new_primary(); + end +end + +function occupant_mt:each_session() + return pairs(self.sessions) +end + +function occupant_mt:get_presence(real_jid) + return self.sessions[real_jid or self.jid] +end + +return { + new = new_occupant; + copy = copy_occupant; + mt = occupant_mt; +} diff --git a/plugins/muc/password.lib.lua b/plugins/muc/password.lib.lua new file mode 100644 index 00000000..02ecdc1a --- /dev/null +++ b/plugins/muc/password.lib.lua @@ -0,0 +1,70 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local st = require "util.stanza"; + +local function get_password(room) + return room._data.password; +end + +local function set_password(room, password) + if password == "" then password = nil; end + if room._data.password == password then return false; end + room._data.password = password; + return true; +end + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = get_password(event.room) and "muc_passwordprotected" or "muc_unsecured"}):up(); +end); + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_roomsecret"; + type = "text-private"; + label = "Password"; + value = get_password(event.room) or ""; + }); +end, 100-7); + +module:hook("muc-config-submitted/muc#roomconfig_roomsecret", function(event) + if set_password(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +-- Don't allow anyone to join room unless they provide the password +module:hook("muc-occupant-pre-join", function(event) + local room, stanza = event.room, event.stanza; + local password = stanza:get_child("x", "http://jabber.org/protocol/muc"); + password = password and password:get_child_text("password", "http://jabber.org/protocol/muc"); + if not password or password == "" then password = nil; end + if get_password(room) ~= password then + local from, to = stanza.attr.from, stanza.attr.to; + module:log("debug", "%s couldn't join due to invalid password: %s", from, to); + local reply = st.error_reply(stanza, "auth", "not-authorized"):up(); + reply.tags[1].attr.code = "401"; + event.origin.send(reply:tag("x", {xmlns = "http://jabber.org/protocol/muc"})); + return true; + end +end, -20); + +-- Add password to outgoing invite +module:hook("muc-invite", function(event) + local password = get_password(event.room); + if password then + local x = event.stanza:get_child("x", "http://jabber.org/protocol/muc#user"); + x:tag("password"):text(password):up(); + end +end); + +return { + get = get_password; + set = set_password; +}; diff --git a/plugins/muc/persistent.lib.lua b/plugins/muc/persistent.lib.lua new file mode 100644 index 00000000..abceafe1 --- /dev/null +++ b/plugins/muc/persistent.lib.lua @@ -0,0 +1,47 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local function get_persistent(room) + return room._data.persistent; +end + +local function set_persistent(room, persistent) + persistent = persistent and true or nil; + if get_persistent(room) == persistent then return false; end + room._data.persistent = persistent; + return true; +end + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_persistentroom"; + type = "boolean"; + label = "Make Room Persistent?"; + value = get_persistent(event.room); + }); +end, 100-3); + +module:hook("muc-config-submitted/muc#roomconfig_persistentroom", function(event) + if set_persistent(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = get_persistent(event.room) and "muc_persistent" or "muc_temporary"}):up(); +end); + +module:hook("muc-room-destroyed", function(event) + set_persistent(event.room, false); +end); + +return { + get = get_persistent; + set = set_persistent; +}; diff --git a/plugins/muc/request.lib.lua b/plugins/muc/request.lib.lua new file mode 100644 index 00000000..d7fa9426 --- /dev/null +++ b/plugins/muc/request.lib.lua @@ -0,0 +1,126 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local st = require "util.stanza"; +local jid_resource = require "util.jid".resource; + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = "http://jabber.org/protocol/muc#request"}):up(); +end); + +local voice_request_form = require "util.dataforms".new({ + title = "Voice Request"; + { + name = "FORM_TYPE"; + type = "hidden"; + value = "http://jabber.org/protocol/muc#request"; + }, + { + name = "muc#jid"; + type = "jid-single"; + label = "User ID"; + }, + { + name = "muc#roomnick"; + type = "text-single"; + label = "Room Nickname"; + }, + { + name = "muc#role"; + type = "list-single"; + label = "Requested Role"; + value = "participant"; + options = { + "none", + "visitor", + "participant", + "moderator", + }; + }, + { + name = "muc#request_allow"; + type = "boolean"; + label = "Grant voice to this person?"; + value = false; + } +}); + +local function handle_request(room, origin, stanza, form) + local occupant = room:get_occupant_by_real_jid(stanza.attr.from); + local fields = voice_request_form:data(form); + local event = { + room = room; + origin = origin; + stanza = stanza; + fields = fields; + occupant = occupant; + }; + if occupant.role == "moderator" then + module:log("debug", "%s responded to a voice request in %s", jid_resource(occupant.nick), room.jid); + module:fire_event("muc-voice-response", event); + else + module:log("debug", "%s requested voice in %s", jid_resource(occupant.nick), room.jid); + module:fire_event("muc-voice-request", event); + end +end + +module:hook("muc-voice-request", function(event) + if event.occupant.role == "visitor" then + local nick = jid_resource(event.occupant.nick); + local formdata = { + ["muc#jid"] = event.stanza.attr.from; + ["muc#roomnick"] = nick; + }; + + local message = st.message({ type = "normal"; from = event.room.jid }):add_child(voice_request_form:form(formdata)):up(); + + event.room:broadcast(message, function (_, occupant) + return occupant.role == "moderator"; + end); + end +end); + +module:hook("muc-voice-response", function(event) + local actor = event.stanza.attr.from; + local affected_occupant = event.room:get_occupant_by_real_jid(event.fields["muc#jid"]); + local occupant = event.occupant; + + if occupant.role ~= "moderator" then + module:log("debug", "%s tried to grant voice but wasn't a moderator", jid_resource(occupant.nick)); + return; + end + + if not event.fields["muc#request_allow"] then + module:log("debug", "%s did not grant voice", jid_resource(occupant.nick)); + return; + end + + if not affected_occupant then + module:log("debug", "%s tried to grant voice to unknown occupant %s", jid_resource(occupant.nick), event.fields["muc#jid"]); + return; + end + + if affected_occupant.role ~= "visitor" then + module:log("debug", "%s tried to grant voice to %s but they already have it", jid_resource(occupant.nick), jid_resource(occupant.jid)); + return; + end + + module:log("debug", "%s granted voice to %s", jid_resource(event.occupant.nick), jid_resource(occupant.jid)); + local ok, errtype, err = event.room:set_role(actor, affected_occupant.nick, "participant", "Voice granted"); + + if not ok then + module:log("debug", "Error granting voice: %s", err or errtype); + event.origin.send(st.error_reply(event.stanza, errtype, err)); + end +end); + + +return { + handle_request = handle_request; +}; diff --git a/plugins/muc/subject.lib.lua b/plugins/muc/subject.lib.lua new file mode 100644 index 00000000..56d8d174 --- /dev/null +++ b/plugins/muc/subject.lib.lua @@ -0,0 +1,118 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local st = require "util.stanza"; +local dt = require "util.datetime"; + +local muc_util = module:require "muc/util"; +local valid_roles = muc_util.valid_roles; + +local function create_subject_message(from, subject) + return st.message({from = from; type = "groupchat"}) + :tag("subject"):text(subject or ""):up(); +end + +local function get_changesubject(room) + return room._data.changesubject; +end + +local function set_changesubject(room, changesubject) + changesubject = changesubject and true or nil; + if get_changesubject(room) == changesubject then return false; end + room._data.changesubject = changesubject; + return true; +end + +module:hook("muc-disco#info", function (event) + table.insert(event.form, { + name = "muc#roominfo_changesubject"; + type = "boolean"; + }); + event.formdata["muc#roominfo_changesubject"] = get_changesubject(event.room); +end); + +module:hook("muc-config-form", function(event) + table.insert(event.form, { + name = "muc#roomconfig_changesubject"; + type = "boolean"; + label = "Allow Occupants to Change Subject?"; + value = get_changesubject(event.room); + }); +end, 100-8); + +module:hook("muc-config-submitted/muc#roomconfig_changesubject", function(event) + if set_changesubject(event.room, event.value) then + event.status_codes["104"] = true; + end +end); + +local function get_subject(room) + -- a <message/> stanza from the room JID (or from the occupant JID of the entity that set the subject) + return room._data.subject_from or room.jid, room._data.subject; +end + +local function send_subject(room, to, time) + local msg = create_subject_message(get_subject(room)); + msg.attr.to = to; + if time then + msg:tag("delay", { + xmlns = "urn:xmpp:delay", + from = room.jid, + stamp = dt.datetime(time); + }):up(); + end + room:route_stanza(msg); +end + +local function set_subject(room, from, subject) + if subject == "" then subject = nil; end + local old_from, old_subject = get_subject(room); + if old_subject == subject and old_from == from then return false; end + room._data.subject_from = from; + room._data.subject = subject; + room._data.subject_time = os.time(); + local msg = create_subject_message(from, subject); + room:broadcast_message(msg); + return true; +end + +-- Send subject to joining user +module:hook("muc-occupant-session-new", function(event) + send_subject(event.room, event.stanza.attr.from, event.room._data.subject_time); +end, 20); + +-- Prosody has made the decision that messages with <subject/> are exclusively subject changes +-- e.g. body will be ignored; even if the subject change was not allowed +module:hook("muc-occupant-groupchat", function(event) + local stanza = event.stanza; + local subject = stanza:get_child("subject"); + if subject then + local room = event.room; + local occupant = event.occupant; + -- Role check for subject changes + local role_rank = valid_roles[occupant and occupant.role or "none"]; + if role_rank >= valid_roles.moderator or + ( role_rank >= valid_roles.participant and get_changesubject(room) ) then -- and participant + set_subject(room, occupant.nick, subject:get_text()); + room:save(); + return true; + else + event.origin.send(st.error_reply(stanza, "auth", "forbidden", "You are not allowed to change the subject")); + return true; + end + end +end, 20); + +return { + get_changesubject = get_changesubject; + set_changesubject = set_changesubject; + get = get_subject; + set = set_subject; + send = send_subject; +}; diff --git a/plugins/muc/util.lib.lua b/plugins/muc/util.lib.lua new file mode 100644 index 00000000..16deb543 --- /dev/null +++ b/plugins/muc/util.lib.lua @@ -0,0 +1,58 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local _M = {}; + +_M.valid_affiliations = { + outcast = -1; + none = 0; + member = 1; + admin = 2; + owner = 3; +}; + +_M.valid_roles = { + none = 0; + visitor = 1; + participant = 2; + moderator = 3; +}; + +local kickable_error_conditions = { + ["gone"] = true; + ["internal-server-error"] = true; + ["item-not-found"] = true; + ["jid-malformed"] = true; + ["recipient-unavailable"] = true; + ["redirect"] = true; + ["remote-server-not-found"] = true; + ["remote-server-timeout"] = true; + ["service-unavailable"] = true; + ["malformed error"] = true; +}; +function _M.is_kickable_error(stanza) + local cond = select(2, stanza:get_error()) or "malformed error"; + return kickable_error_conditions[cond]; +end + +local muc_x_filters = { + ["http://jabber.org/protocol/muc"] = true; + ["http://jabber.org/protocol/muc#user"] = true; +} +local function muc_x_filter(tag) + if muc_x_filters[tag.attr.xmlns] then + return nil; + end + return tag; +end +function _M.filter_muc_x(stanza) + return stanza:maptags(muc_x_filter); +end + +return _M; diff --git a/plugins/muc/whois.lib.lua b/plugins/muc/whois.lib.lua new file mode 100644 index 00000000..4acf288c --- /dev/null +++ b/plugins/muc/whois.lib.lua @@ -0,0 +1,65 @@ +-- Prosody IM +-- Copyright (C) 2008-2010 Matthew Wild +-- Copyright (C) 2008-2010 Waqas Hussain +-- Copyright (C) 2014 Daurnimator +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +local valid_whois = { + moderators = true; + anyone = true; +}; + +local function get_whois(room) + return room._data.whois or "moderators"; +end + +local function set_whois(room, whois) + assert(valid_whois[whois], "Invalid whois value") + if get_whois(room) == whois then return false; end + room._data.whois = whois; + return true; +end + +module:hook("muc-disco#info", function(event) + event.reply:tag("feature", {var = get_whois(event.room) ~= "anyone" and "muc_semianonymous" or "muc_nonanonymous"}):up(); +end); + +module:hook("muc-config-form", function(event) + local whois = get_whois(event.room); + table.insert(event.form, { + name = 'muc#roomconfig_whois', + type = 'list-single', + label = 'Who May Discover Real JIDs?', + value = { + { value = 'moderators', label = 'Moderators Only', default = whois == 'moderators' }, + { value = 'anyone', label = 'Anyone', default = whois == 'anyone' } + } + }); +end, 100-9); + +module:hook("muc-config-submitted/muc#roomconfig_whois", function(event) + if set_whois(event.room, event.value) then + local code = (event.value == 'moderators') and "173" or "172"; + event.status_codes[code] = true; + end +end); + +-- Mask 'from' jid as occupant jid if room is anonymous +module:hook("muc-invite", function(event) + local room, stanza = event.room, event.stanza; + if get_whois(room) == "moderators" and room:get_default_role(room:get_affiliation(stanza.attr.to)) ~= "moderator" then + local invite = stanza:get_child("x", "http://jabber.org/protocol/muc#user"):get_child("invite"); + local occupant_jid = room:get_occupant_jid(invite.attr.from); + if occupant_jid ~= nil then -- FIXME: This will expose real jid if inviter is not in room + invite.attr.from = occupant_jid; + end + end +end, 50); + +return { + get = get_whois; + set = set_whois; +}; @@ -49,390 +49,49 @@ if #arg > 0 and arg[1] ~= "--config" then return 1; end --- Global 'prosody' object -local prosody = { events = require "util.events".new(); }; -_G.prosody = prosody; +local startup = require "util.startup"; +local async = require "util.async"; --- Check dependencies -local dependencies = require "util.dependencies"; +-- Note: it's important that this thread is not GC'd, as some C libraries +-- that are initialized here store a pointer to it ( :/ ). +local thread = async.runner(); --- Load the config-parsing module -config = require "core.configmanager" +thread:run(startup.prosody); --- -- -- -- --- Define the functions we call during startup, the --- actual startup happens right at the end, where these --- functions get called - -function read_config() - local filenames = {}; - - local filename; - if arg[1] == "--config" and arg[2] then - table.insert(filenames, arg[2]); - if CFG_CONFIGDIR then - table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); - end - elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl - table.insert(filenames, os.getenv("PROSODY_CONFIG")); - else - for _, format in ipairs(config.parsers()) do - table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg."..format); - end - end - for _,_filename in ipairs(filenames) do - filename = _filename; - local file = io.open(filename); - if file then - file:close(); - CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); - break; - end - end - prosody.config_file = filename - local ok, level, err = config.load(filename); - if not ok then - print("\n"); - print("**************************"); - if level == "parser" then - print("A problem occured while reading the config file "..filename); - print(""); - local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)"); - if err:match("chunk has too many syntax levels$") then - print("An Include statement in a config file is including an already-included"); - print("file and causing an infinite loop. An Include statement in a config file is..."); - else - print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err))); - end - print(""); - elseif level == "file" then - print("Prosody was unable to find the configuration file."); - print("We looked for: "..filename); - print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist"); - print("Copy or rename it to prosody.cfg.lua and edit as necessary."); - end - print("More help on configuring Prosody can be found at http://prosody.im/doc/configure"); - print("Good luck!"); - print("**************************"); - print(""); - os.exit(1); - end -end - -function check_dependencies() - if not dependencies.check_dependencies() then - os.exit(1); - end -end - --- luacheck: globals socket server - -function load_libraries() - -- Load socket framework - -- luacheck: ignore 111/server 111/socket - socket = require "socket"; - server = require "net.server" -end - --- The global log() gets defined by loggingmanager --- luacheck: ignore 113/log - -function init_logging() - -- Initialize logging - require "core.loggingmanager" -end - -function log_dependency_warnings() - dependencies.log_warnings(); -end - -function sanity_check() - for host, host_config in pairs(config.getconfig()) do - if host ~= "*" - and host_config.enabled ~= false - and not host_config.component_module then - return; - end - end - log("error", "No enabled VirtualHost entries found in the config file."); - log("error", "At least one active host is required for Prosody to function. Exiting..."); - os.exit(1); -end - -function sandbox_require() - -- Replace require() with one that doesn't pollute _G, required - -- for neat sandboxing of modules - -- luacheck: ignore 113/getfenv 111/require - local _realG = _G; - local _real_require = require; - local getfenv = getfenv or function (f) - -- FIXME: This is a hack to replace getfenv() in Lua 5.2 - local name, env = debug.getupvalue(debug.getinfo(f or 1).func, 1); - if name == "_ENV" then - return env; - end - end - function require(...) - local curr_env = getfenv(2); - local curr_env_mt = getmetatable(curr_env); - local _realG_mt = getmetatable(_realG); - if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then - local old_newindex, old_index; - old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env; - old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) -- luacheck: ignore 212/_G - return rawget(curr_env, k); - end; - local ret = _real_require(...); - _realG_mt.__newindex = old_newindex; - _realG_mt.__index = old_index; - return ret; - end - return _real_require(...); - end -end - -function set_function_metatable() - local mt = {}; - function mt.__index(f, upvalue) - local i, name, value = 0; - repeat - i = i + 1; - name, value = debug.getupvalue(f, i); - until name == upvalue or name == nil; - return value; - end - function mt.__newindex(f, upvalue, value) - local i, name = 0; - repeat - i = i + 1; - name = debug.getupvalue(f, i); - until name == upvalue or name == nil; - if name then - debug.setupvalue(f, i, value); - end - end - function mt.__tostring(f) - local info = debug.getinfo(f); - return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined); - end - debug.setmetatable(function() end, mt); -end - -function init_global_state() - prosody.bare_sessions = {}; - prosody.full_sessions = {}; - prosody.hosts = {}; - - -- COMPAT: These globals are deprecated - -- luacheck: ignore 111/bare_sessions 111/full_sessions 111/hosts - bare_sessions = prosody.bare_sessions; - full_sessions = prosody.full_sessions; - hosts = prosody.hosts; - - local data_path = config.get("*", "data_path") or CFG_DATADIR or "data"; - local custom_plugin_paths = config.get("*", "plugin_paths"); - if custom_plugin_paths then - local path_sep = package.config:sub(3,3); - -- path1;path2;path3;defaultpath... - CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); - end - prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".", - plugins = CFG_PLUGINDIR or "plugins", data = data_path }; - - prosody.arg = _G.arg; - - prosody.platform = "unknown"; - if os.getenv("WINDIR") then - prosody.platform = "windows"; - elseif package.config:sub(1,1) == "/" then - prosody.platform = "posix"; - end - - prosody.installed = nil; - if CFG_SOURCEDIR and (prosody.platform == "windows" or CFG_SOURCEDIR:match("^/")) then - prosody.installed = true; - end - - if prosody.installed then - -- Change working directory to data path. - require "lfs".chdir(data_path); - end - - -- Function to reload the config file - function prosody.reload_config() - log("info", "Reloading configuration file"); - prosody.events.fire_event("reloading-config"); - local ok, level, err = config.load(prosody.config_file); - if not ok then - if level == "parser" then - log("error", "There was an error parsing the configuration file: %s", tostring(err)); - elseif level == "file" then - log("error", "Couldn't read the config file when trying to reload: %s", tostring(err)); - end - end - return ok, (err and tostring(level)..": "..tostring(err)) or nil; - end - - -- Function to reopen logfiles - function prosody.reopen_logfiles() - log("info", "Re-opening log files"); - prosody.events.fire_event("reopen-log-files"); - end - - -- Function to initiate prosody shutdown - function prosody.shutdown(reason, code) - log("info", "Shutting down: %s", reason or "unknown reason"); - prosody.shutdown_reason = reason; - prosody.shutdown_code = code; - prosody.events.fire_event("server-stopping", { - reason = reason; - code = code; - }); - server.setquitting(true); - end -end - -function read_version() - -- Try to determine version - local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version"); - if version_file then - prosody.version = version_file:read("*a"):gsub("%s*$", ""); - version_file:close(); - if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then - prosody.version = "hg:"..prosody.version; - end - else - prosody.version = "unknown"; - end -end - -function load_secondary_libraries() - --- Load and initialise core modules - require "util.import" - require "util.xmppstream" - require "core.stanza_router" - require "core.statsmanager" - require "core.hostmanager" - require "core.portmanager" - require "core.modulemanager" - require "core.usermanager" - require "core.rostermanager" - require "core.sessionmanager" - package.loaded['core.componentmanager'] = setmetatable({},{__index=function() - log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); - return function() end - end}); - - local http = require "net.http" - local config_ssl = config.get("*", "ssl") or {} - local https_client = config.get("*", "client_https_ssl") - http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", - { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); - - require "util.array" - require "util.datetime" - require "util.iterators" - require "util.timer" - require "util.helpers" - - pcall(require, "util.signal") -- Not on Windows - - -- Commented to protect us from - -- the second kind of people - --[[ - pcall(require, "remdebug.engine"); - if remdebug then remdebug.engine.start() end - ]] - - require "util.stanza" - require "util.jid" -end - -function init_data_store() - require "core.storagemanager"; -end - -function prepare_to_start() - log("info", "Prosody is using the %s backend for connection handling", server.get_backend()); - -- Signal to modules that we are ready to start - prosody.events.fire_event("server-starting"); - prosody.start_time = os.time(); -end - -function init_global_protection() - -- Catch global accesses - -- luacheck: ignore 212/t - local locked_globals_mt = { - __index = function (t, k) log("warn", "%s", debug.traceback("Attempt to read a non-existent global '"..tostring(k).."'", 2)); end; - __newindex = function (t, k, v) error("Attempt to set a global: "..tostring(k).." = "..tostring(v), 2); end; - }; - - function prosody.unlock_globals() - setmetatable(_G, nil); - end - - function prosody.lock_globals() - setmetatable(_G, locked_globals_mt); - end - - -- And lock now... - prosody.lock_globals(); -end - -function loop() +local function loop() -- Error handler for errors that make it this far local function catch_uncaught_error(err) if type(err) == "string" and err:match("interrupted!$") then return "quitting"; end - log("error", "Top-level error, please report:\n%s", tostring(err)); + prosody.log("error", "Top-level error, please report:\n%s", tostring(err)); local traceback = debug.traceback("", 2); if traceback then - log("error", "%s", traceback); + prosody.log("error", "%s", traceback); end prosody.events.fire_event("very-bad-error", {error = err, traceback = traceback}); end local sleep = require"socket".sleep; + local server = require "net.server"; while select(2, xpcall(server.loop, catch_uncaught_error)) ~= "quitting" do sleep(0.2); end end -function cleanup() - log("info", "Shutdown status: Cleaning up"); +local function cleanup() + prosody.log("info", "Shutdown status: Cleaning up"); prosody.events.fire_event("server-cleanup"); end --- Are you ready? :) --- These actions are in a strict order, as many depend on --- previous steps to have already been performed -read_config(); -init_logging(); -sanity_check(); -sandbox_require(); -set_function_metatable(); -check_dependencies(); -load_libraries(); -init_global_state(); -read_version(); -log("info", "Hello and welcome to Prosody version %s", prosody.version); -log_dependency_warnings(); -load_secondary_libraries(); -init_data_store(); -init_global_protection(); -prepare_to_start(); - -prosody.events.fire_event("server-started"); - loop(); -log("info", "Shutting down..."); +prosody.log("info", "Shutting down..."); cleanup(); prosody.events.fire_event("server-stopped"); -log("info", "Shutdown complete"); +prosody.log("info", "Shutdown complete"); -os.exit(prosody.shutdown_code) +os.exit(prosody.shutdown_code); @@ -20,8 +20,8 @@ CFG_DATADIR=CFG_DATADIR or os.getenv("PROSODY_DATADIR"); local function is_relative(path) local path_sep = package.config:sub(1,1); - return ((path_sep == "/" and path:sub(1,1) ~= "/") - or (path_sep == "\\" and (path:sub(1,1) ~= "/" and path:sub(2,3) ~= ":\\"))) + return ((path_sep == "/" and path:sub(1,1) ~= "/") + or (path_sep == "\\" and (path:sub(1,1) ~= "/" and path:sub(2,3) ~= ":\\"))) end -- Tell Lua where to find our libraries @@ -43,190 +43,12 @@ if CFG_DATADIR then end end --- Global 'prosody' object -local prosody = { - hosts = {}; - events = require "util.events".new(); - platform = "posix"; - lock_globals = function () end; - unlock_globals = function () end; - installed = CFG_SOURCEDIR ~= nil; - core_post_stanza = function () end; -- TODO: mod_router! -}; -_G.prosody = prosody; - -local dependencies = require "util.dependencies"; -if not dependencies.check_dependencies() then - os.exit(1); -end - -config = require "core.configmanager" - -local ENV_CONFIG; -do - local filenames = {}; - - local filename; - if arg[1] == "--config" and arg[2] then - table.insert(filenames, arg[2]); - if CFG_CONFIGDIR then - table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); - end - table.remove(arg, 1); table.remove(arg, 1); - else - for _, format in ipairs(config.parsers()) do - table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg."..format); - end - end - for _,_filename in ipairs(filenames) do - filename = _filename; - local file = io.open(filename); - if file then - file:close(); - ENV_CONFIG = filename; - CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); - break; - end - end - local ok, level, err = config.load(filename); - if not ok then - print("\n"); - print("**************************"); - if level == "parser" then - print("A problem occured while reading the config file "..filename); - local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)"); - print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err))); - print(""); - elseif level == "file" then - print("Prosody was unable to find the configuration file."); - print("We looked for: "..filename); - print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist"); - print("Copy or rename it to prosody.cfg.lua and edit as necessary."); - end - print("More help on configuring Prosody can be found at http://prosody.im/doc/configure"); - print("Good luck!"); - print("**************************"); - print(""); - os.exit(1); - end -end -local original_logging_config = config.get("*", "log"); -config.set("*", "log", { { levels = { min = os.getenv("PROSODYCTL_LOG_LEVEL") or "info" }, to = "console" } }); - -local data_path = config.get("*", "data_path") or CFG_DATADIR or "data"; -local custom_plugin_paths = config.get("*", "plugin_paths"); -if custom_plugin_paths then - local path_sep = package.config:sub(3,3); - -- path1;path2;path3;defaultpath... - CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); -end -prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR, - plugins = CFG_PLUGINDIR or "plugins", data = data_path }; - -if prosody.installed then - -- Change working directory to data path. - require "lfs".chdir(data_path); -end +----------- -require "core.loggingmanager" - -dependencies.log_warnings(); - --- Switch away from root and into the prosody user -- -local switched_user, current_uid; - -local want_pposix_version = "0.4.0"; -local have_pposix, pposix = pcall(require, "util.pposix"); - -if have_pposix and pposix then - if pposix._VERSION ~= want_pposix_version then - print(string.format("Unknown version (%s) of binary pposix module, expected %s", - tostring(pposix._VERSION), want_pposix_version)); return; - end - current_uid = pposix.getuid(); - local arg_root = arg[1] == "--root"; - if arg_root then table.remove(arg, 1); end - if current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then - -- We haz root! - local desired_user = config.get("*", "prosody_user") or "prosody"; - local desired_group = config.get("*", "prosody_group") or desired_user; - local ok, err = pposix.setgid(desired_group); - if ok then - ok, err = pposix.initgroups(desired_user); - end - if ok then - ok, err = pposix.setuid(desired_user); - if ok then - -- Yay! - switched_user = true; - end - end - if not switched_user then - -- Boo! - print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err)); - else - -- Make sure the Prosody user can read the config - local conf, err, errno = io.open(ENV_CONFIG); - if conf then - conf:close(); - else - print("The config file is not readable by the '"..desired_user.."' user."); - print("Prosody will not be able to read it."); - print("Error was "..err); - os.exit(1); - end - end - end - - -- Set our umask to protect data files - pposix.umask(config.get("*", "umask") or "027"); - pposix.setenv("HOME", data_path); - pposix.setenv("PROSODY_CONFIG", ENV_CONFIG); -else - print("Error: Unable to load pposix module. Check that Prosody is installed correctly.") - print("For more help send the below error to us through http://prosody.im/discuss"); - print(tostring(pposix)) - os.exit(1); -end - -local function test_writeable(filename) - local f, err = io.open(filename, "a"); - if not f then - return false, err; - end - f:close(); - return true; -end - -local unwriteable_files = {}; -if type(original_logging_config) == "string" and original_logging_config:sub(1,1) ~= "*" then - local ok, err = test_writeable(original_logging_config); - if not ok then - table.insert(unwriteable_files, err); - end -elseif type(original_logging_config) == "table" then - for _, rule in ipairs(original_logging_config) do - if rule.filename then - local ok, err = test_writeable(rule.filename); - if not ok then - table.insert(unwriteable_files, err); - end - end - end -end - -if #unwriteable_files > 0 then - print("One of more of the Prosody log files are not"); - print("writeable, please correct the errors and try"); - print("starting prosodyctl again."); - print(""); - for _, err in ipairs(unwriteable_files) do - print(err); - end - print(""); - os.exit(1); -end +local startup = require "util.startup"; +startup.prosodyctl(); +----------- local error_messages = setmetatable({ ["invalid-username"] = "The given username is invalid in a Jabber ID"; @@ -235,60 +57,21 @@ local error_messages = setmetatable({ ["no-such-user"] = "The given user does not exist on the server"; ["no-such-host"] = "The given hostname does not exist in the config"; ["unable-to-save-data"] = "Unable to store, perhaps you don't have permission?"; - ["no-pidfile"] = "There is no 'pidfile' option in the configuration file, see http://prosody.im/doc/prosodyctl#pidfile for help"; - ["invalid-pidfile"] = "The 'pidfile' option in the configuration file is not a string, see http://prosody.im/doc/prosodyctl#pidfile for help"; - ["no-posix"] = "The mod_posix module is not enabled in the Prosody config file, see http://prosody.im/doc/prosodyctl for more info"; + ["no-pidfile"] = "There is no 'pidfile' option in the configuration file, see https://prosody.im/doc/prosodyctl#pidfile for help"; + ["invalid-pidfile"] = "The 'pidfile' option in the configuration file is not a string, see https://prosody.im/doc/prosodyctl#pidfile for help"; + ["no-posix"] = "The mod_posix module is not enabled in the Prosody config file, see https://prosody.im/doc/prosodyctl for more info"; ["no-such-method"] = "This module has no commands"; ["not-running"] = "Prosody is not running"; - }, { __index = function (t,k) return "Error: "..(tostring(k):gsub("%-", " "):gsub("^.", string.upper)); end }); - -hosts = prosody.hosts; - -local function make_host(hostname) - return { - type = "local", - events = prosody.events, - modules = {}, - sessions = {}, - users = require "core.usermanager".new_null_provider(hostname) - }; -end - -for hostname, config in pairs(config.getconfig()) do - hosts[hostname] = make_host(hostname); -end + }, { __index = function (_,k) return "Error: "..(tostring(k):gsub("%-", " "):gsub("^.", string.upper)); end }); +local configmanager = require "core.configmanager"; local modulemanager = require "core.modulemanager" - local prosodyctl = require "util.prosodyctl" local socket = require "socket" - -local http = require "net.http" -local config_ssl = config.get("*", "ssl") or {} -local https_client = config.get("*", "client_https_ssl") -http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", - { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); +local dependencies = require "util.dependencies"; ----------------------- --- FIXME: Duplicate code waiting for util.startup -function read_version() - -- Try to determine version - local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version"); - prosody.version = "unknown"; - if version_file then - prosody.version = version_file:read("*a"):gsub("%s*$", ""); - version_file:close(); - if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then - prosody.version = "hg:"..prosody.version; - end - else - local hg = require"util.mercurial"; - local hgid = hg.check_id(CFG_SOURCEDIR or "."); - if hgid then prosody.version = "hg:" .. hgid; end - end -end - local show_message, show_warning = prosodyctl.show_message, prosodyctl.show_warning; local show_usage = prosodyctl.show_usage; local show_yesno = prosodyctl.show_yesno; @@ -297,7 +80,7 @@ local read_password = prosodyctl.read_password; local jid_split = require "util.jid".prepped_split; -local prosodyctl_timeout = (config.get("*", "prosodyctl_timeout") or 5) * 2; +local prosodyctl_timeout = (configmanager.get("*", "prosodyctl_timeout") or 5) * 2; ----------------------- local commands = {}; local command = arg[1]; @@ -319,10 +102,10 @@ function commands.adduser(arg) return 1; end - if not hosts[host] then + if not prosody.hosts[host] then show_warning("The host '%s' is not listed in the configuration file (or is not enabled).", host) show_warning("The user will not be able to log in until this is changed."); - hosts[host] = make_host(host); + prosody.hosts[host] = startup.make_host(host); --luacheck: ignore 122 end if prosodyctl.user_exists{ user = user, host = host } then @@ -358,10 +141,10 @@ function commands.passwd(arg) return 1; end - if not hosts[host] then + if not prosody.hosts[host] then show_warning("The host '%s' is not listed in the configuration file (or is not enabled).", host) show_warning("The user will not be able to log in until this is changed."); - hosts[host] = make_host(host); + prosody.hosts[host] = startup.make_host(host); --luacheck: ignore 122 end if not prosodyctl.user_exists { user = user, host = host } then @@ -397,9 +180,9 @@ function commands.deluser(arg) return 1; end - if not hosts[host] then + if not prosody.hosts[host] then show_warning("The host '%s' is not listed in the configuration file (or is not enabled).", host) - hosts[host] = make_host(host); + prosody.hosts[host] = startup.make_host(host); --luacheck: ignore 122 end if not prosodyctl.user_exists { user = user, host = host } then @@ -427,6 +210,7 @@ function commands.start(arg) end if ret then + --luacheck: ignore 421/ret local ok, ret = prosodyctl.getpid(); if not ok then show_message("Couldn't get running Prosody's PID"); @@ -437,9 +221,10 @@ function commands.start(arg) return 1; end - local ok, ret = prosodyctl.start(); + --luacheck: ignore 411/ret + local ok, ret = prosodyctl.start(prosody.paths.source); if ok then - local daemonize = config.get("*", "daemonize"); + local daemonize = configmanager.get("*", "daemonize"); if daemonize == nil then daemonize = prosody.installed; end @@ -481,6 +266,7 @@ function commands.status(arg) end if ret then + --luacheck: ignore 421/ret local ok, ret = prosodyctl.getpid(); if not ok then show_message("Couldn't get running Prosody's PID"); @@ -491,7 +277,7 @@ function commands.status(arg) return 0; else show_message("Prosody is not running"); - if not switched_user and current_uid ~= 0 then + if not prosody.switched_user and prosody.current_uid ~= 0 then print("\nNote:") print(" You will also see this if prosodyctl is not running under"); print(" the same user account as Prosody. Try running as root (e.g. "); @@ -499,7 +285,6 @@ function commands.status(arg) end return 2 end - return 1; end function commands.stop(arg) @@ -548,28 +333,26 @@ function commands.restart(arg) end function commands.about(arg) - read_version(); if arg[1] == "--help" then show_usage([[about]], [[Show information about this Prosody installation]]); return 1; end local pwd = "."; - local lfs = require "lfs"; local array = require "util.array"; local keys = require "util.iterators".keys; local hg = require"util.mercurial"; - local relpath = config.resolve_relative_path; + local relpath = configmanager.resolve_relative_path; print("Prosody "..(prosody.version or "(unknown version)")); print(""); print("# Prosody directories"); - print("Data directory: "..relpath(pwd, data_path)); - print("Config directory: "..relpath(pwd, CFG_CONFIGDIR or ".")); - print("Source directory: "..relpath(pwd, CFG_SOURCEDIR or ".")); + print("Data directory: "..relpath(pwd, prosody.paths.data)); + print("Config directory: "..relpath(pwd, prosody.paths.config or ".")); + print("Source directory: "..relpath(pwd, prosody.paths.source or ".")); print("Plugin directories:") print(" "..(prosody.paths.plugins:gsub("([^;]+);?", function(path) - path = config.resolve_relative_path(pwd, path); + path = configmanager.resolve_relative_path(pwd, path); local hgid, hgrepo = hg.check_id(path); if not hgid and hgrepo then return path.." - "..hgrepo .."!\n "; @@ -593,15 +376,21 @@ function commands.about(arg) print(" "..path); end print(""); - local luarocks_status = (pcall(require, "luarocks.loader") and "Installed ("..(package.loaded["luarocks.cfg"].program_version or "2.x+")..")") - or (pcall(require, "luarocks.require") and "Installed (1.x)") - or "Not installed"; + local luarocks_status = "Not installed" + if pcall(require, "luarocks.loader") then + luarocks_status = "Installed (2.x+)"; + if package.loaded["luarocks.cfg"] then + luarocks_status = "Installed ("..(package.loaded["luarocks.cfg"].program_version or "2.x+")..")"; + end + elseif pcall(require, "luarocks.require") then + luarocks_status = "Installed (1.x)"; + end print("LuaRocks: ", luarocks_status); print(""); print("# Lua module versions"); local module_versions, longest_name = {}, 8; local luaevent =dependencies.softreq"luaevent"; - local ssl = dependencies.softreq"ssl"; + dependencies.softreq"ssl"; for name, module in pairs(package.loaded) do if type(module) == "table" and rawget(module, "_VERSION") and name ~= "_G" and not name:match("%.") then @@ -718,11 +507,12 @@ local function use_existing(filename) end end -local cert_basedir = CFG_DATADIR or "./certs"; +local have_pposix, pposix = pcall(require, "util.pposix"); +local cert_basedir = prosody.paths.data == "." and "./certs" or prosody.paths.data; if have_pposix and pposix.getuid() == 0 then -- FIXME should be enough to check if this directory is writable - local cert_dir = config.get("*", "certificates") or "certs"; - cert_basedir = config.resolve_relative_path(prosody.paths.config, cert_dir); + local cert_dir = configmanager.get("*", "certificates") or "certs"; + cert_basedir = configmanager.resolve_relative_path(prosody.paths.config, cert_dir); end function cert_commands.config(arg) @@ -736,7 +526,7 @@ function cert_commands.config(arg) distinguished_name = table.remove(arg); end local conf = openssl.config.new(); - conf:from_prosody(hosts, config, arg); + conf:from_prosody(prosody.hosts, configmanager, arg); if distinguished_name then local dn = {}; for k, v in distinguished_name:gmatch("/([^=/]+)=([^/]+)") do @@ -750,7 +540,7 @@ function cert_commands.config(arg) for _, k in ipairs(openssl._DN_order) do local v = conf.distinguished_name[k]; if v then - local nv; + local nv = nil; if k == "commonName" then v = arg[1] elseif k == "emailAddress" then @@ -889,7 +679,7 @@ function cert_commands.import(arg) end else for host in pairs(prosody.hosts) do - if host ~= "*" and config.get(host, "enabled") ~= false then + if host ~= "*" and configmanager.get(host, "enabled") ~= false then table.insert(hostnames, host); end end @@ -902,8 +692,8 @@ function cert_commands.import(arg) end local owner, group; if pposix.getuid() == 0 then -- We need root to change ownership - owner = config.get("*", "prosody_user") or "prosody"; - group = config.get("*", "prosody_group") or owner; + owner = configmanager.get("*", "prosody_user") or "prosody"; + group = configmanager.get("*", "prosody_group") or owner; end local cm = require "core.certmanager"; local imported = {}; @@ -961,7 +751,7 @@ function commands.cert(arg) show_message"You need to supply at least one hostname" arg = { "--help" }; end - if arg[1] ~= "--help" and not hosts[arg[1]] then + if arg[1] ~= "--help" and not prosody.hosts[arg[1]] then show_message(error_messages["no-such-host"]); return 1; end @@ -984,26 +774,26 @@ function commands.check(arg) return 1; end local what = table.remove(arg, 1); - local array, set = require "util.array", require "util.set"; + local set = require "util.set"; local it = require "util.iterators"; local ok = true; local function disabled_hosts(host, conf) return host ~= "*" and conf.enabled ~= false; end - local function enabled_hosts() return it.filter(disabled_hosts, pairs(config.getconfig())); end + local function enabled_hosts() return it.filter(disabled_hosts, pairs(configmanager.getconfig())); end if not (what == nil or what == "disabled" or what == "config" or what == "dns" or what == "certs") then show_warning("Don't know how to check '%s'. Try one of 'config', 'dns', 'certs' or 'disabled'.", what); return 1; end if not what or what == "disabled" then - local disabled_hosts = set.new(); - for host, host_options in it.filter("*", pairs(config.getconfig())) do + local disabled_hosts_set = set.new(); + for host, host_options in it.filter("*", pairs(configmanager.getconfig())) do if host_options.enabled == false then - disabled_hosts:add(host); + disabled_hosts_set:add(host); end end - if not disabled_hosts:empty() then + if not disabled_hosts_set:empty() then local msg = "Checks will be skipped for these disabled hosts: %s"; if what then msg = "These hosts are disabled: %s"; end - show_warning(msg, tostring(disabled_hosts)); + show_warning(msg, tostring(disabled_hosts_set)); if what then return 0; end print"" end @@ -1019,13 +809,13 @@ function commands.check(arg) "umask", "prosodyctl_timeout", "use_ipv6", "use_libevent", "network_settings", "network_backend", "http_default_host", }); - local config = config.getconfig(); + local config = configmanager.getconfig(); -- Check that we have any global options (caused by putting a host at the top) if it.count(it.filter("log", pairs(config["*"]))) == 0 then ok = false; print(""); print(" No global options defined. Perhaps you have put a host definition at the top") - print(" of the config file? They should be at the bottom, see http://prosody.im/doc/configure#overview"); + print(" of the config file? They should be at the bottom, see https://prosody.im/doc/configure#overview"); end if it.count(enabled_hosts()) == 0 then ok = false; @@ -1041,7 +831,7 @@ function commands.check(arg) if not config["*"].modules_enabled then print(" No global modules_enabled is set?"); local suggested_global_modules; - for host, options in enabled_hosts() do + for host, options in enabled_hosts() do --luacheck: ignore 213/host if not options.component_module and options.modules_enabled then suggested_global_modules = set.intersection(suggested_global_modules or set.new(options.modules_enabled), set.new(options.modules_enabled)); end @@ -1052,6 +842,19 @@ function commands.check(arg) end print(); end + + do -- Check for modules enabled both normally and as components + local modules = set.new(config["*"]["modules_enabled"]); + for host, options in enabled_hosts() do + local component_module = options.component_module; + if component_module and modules:contains(component_module) then + print((" mod_%s is enabled both in modules_enabled and as Component %q %q"):format(component_module, host, component_module)); + print(" This means the service is enabled on all VirtualHosts as well as the Component."); + print(" Are you sure this what you want? It may cause unexpected behaviour."); + end + end + end + -- Check for global options under hosts local global_options = set.new(it.to_array(it.keys(config["*"]))); local deprecated_global_options = set.intersection(global_options, deprecated); @@ -1077,17 +880,17 @@ function commands.check(arg) local n = it.count(misplaced_options); print(" You have "..n.." option"..(n>1 and "s " or " ").."set under "..host.." that should be"); print(" in the global section of the config file, above any VirtualHost or Component definitions,") - print(" see http://prosody.im/doc/configure#overview for more information.") + print(" see https://prosody.im/doc/configure#overview for more information.") print(""); print(" You need to move the following option"..(n>1 and "s" or "")..": "..table.concat(it.to_array(misplaced_options), ", ")); end local subdomain = host:match("^[^.]+"); if not(host_options:contains("component_module")) and (subdomain == "jabber" or subdomain == "xmpp" - or subdomain == "chat" or subdomain == "im") then + or subdomain == "chat" or subdomain == "im") then print(""); print(" Suggestion: If "..host.. " is a new host with no real users yet, consider renaming it now to"); print(" "..host:gsub("^[^.]+%.", "")..". You can use SRV records to redirect XMPP clients and servers to "..host.."."); - print(" For more information see: http://prosody.im/doc/dns"); + print(" For more information see: https://prosody.im/doc/dns"); end end local all_modules = set.new(config["*"].modules_enabled); @@ -1117,14 +920,16 @@ function commands.check(arg) print(" For more information see https://prosody.im/doc/storage"); end end - for host, config in pairs(config) do - if type(rawget(config, "storage")) == "string" and rawget(config, "default_storage") then + for host, host_config in pairs(config) do --luacheck: ignore 213/host + if type(rawget(host_config, "storage")) == "string" and rawget(host_config, "default_storage") then print(""); print(" The 'default_storage' option is not needed if 'storage' is set to a string."); break; end end - local require_encryption = set.intersection(all_options, set.new({"require_encryption", "c2s_require_encryption", "s2s_require_encryption"})):empty(); + local require_encryption = set.intersection(all_options, set.new({ + "require_encryption", "c2s_require_encryption", "s2s_require_encryption" + })):empty(); local ssl = dependencies.softreq"ssl"; if not ssl then if not require_encryption then @@ -1170,8 +975,8 @@ function commands.check(arg) local dns = require "net.dns"; local idna = require "util.encodings".idna; local ip = require "util.ip"; - local c2s_ports = set.new(config.get("*", "c2s_ports") or {5222}); - local s2s_ports = set.new(config.get("*", "s2s_ports") or {5269}); + local c2s_ports = set.new(configmanager.get("*", "c2s_ports") or {5222}); + local s2s_ports = set.new(configmanager.get("*", "s2s_ports") or {5269}); local c2s_srv_required, s2s_srv_required; if not c2s_ports:contains(5222) then @@ -1187,16 +992,20 @@ function commands.check(arg) local fqdn = socket.dns.tohostname(socket.dns.gethostname()); if fqdn then - local res = dns.lookup(idna.to_ascii(fqdn), "A"); - if res then - for _, record in ipairs(res) do - external_addresses:add(record.a); + do + local res = dns.lookup(idna.to_ascii(fqdn), "A"); + if res then + for _, record in ipairs(res) do + external_addresses:add(record.a); + end end end - local res = dns.lookup(idna.to_ascii(fqdn), "AAAA"); - if res then - for _, record in ipairs(res) do - external_addresses:add(record.aaaa); + do + local res = dns.lookup(idna.to_ascii(fqdn), "AAAA"); + if res then + for _, record in ipairs(res) do + external_addresses:add(record.aaaa); + end end end end @@ -1223,13 +1032,18 @@ function commands.check(arg) local all_targets_ok, some_targets_ok = true, false; local node, host = jid_split(jid); + local modules, component_module = modulemanager.get_modules_for_host(host); + if component_module then + modules:add(component_module); + end + local is_component = not not host_options.component_module; print("Checking DNS for "..(is_component and "component" or "host").." "..jid.."..."); if node then print("Only the domain part ("..host..") is used in DNS.") end local target_hosts = set.new(); - if not is_component then + if modules:contains("c2s") then local res = dns.lookup("_xmpp-client._tcp."..idna.to_ascii(host)..".", "SRV"); if res then for _, record in ipairs(res) do @@ -1247,20 +1061,22 @@ function commands.check(arg) end end end - local res = dns.lookup("_xmpp-server._tcp."..idna.to_ascii(host)..".", "SRV"); - if res then - for _, record in ipairs(res) do - target_hosts:add(record.srv.target); - if not s2s_ports:contains(record.srv.port) then - print(" SRV target "..record.srv.target.." contains unknown server port: "..record.srv.port); + if modules:contains("s2s") then + local res = dns.lookup("_xmpp-server._tcp."..idna.to_ascii(host)..".", "SRV"); + if res then + for _, record in ipairs(res) do + target_hosts:add(record.srv.target); + if not s2s_ports:contains(record.srv.port) then + print(" SRV target "..record.srv.target.." contains unknown server port: "..record.srv.port); + end end - end - else - if s2s_srv_required then - print(" No _xmpp-server SRV record found for "..host..", but it looks like you need one."); - all_targets_ok = false; else - target_hosts:add(host); + if s2s_srv_required then + print(" No _xmpp-server SRV record found for "..host..", but it looks like you need one."); + all_targets_ok = false; + else + target_hosts:add(host); + end end end if target_hosts:empty() then @@ -1272,12 +1088,8 @@ function commands.check(arg) target_hosts:remove("localhost"); end - local modules = set.new(it.to_array(it.values(host_options.modules_enabled or {}))) - + set.new(it.to_array(it.values(config.get("*", "modules_enabled") or {}))) - + set.new({ config.get(host, "component_module") }); - if modules:contains("proxy65") then - local proxy65_target = config.get(host, "proxy65_address") or host; + local proxy65_target = configmanager.get(host, "proxy65_address") or host; local A, AAAA = dns.lookup(idna.to_ascii(proxy65_target), "A"), dns.lookup(idna.to_ascii(proxy65_target), "AAAA"); local prob = {}; if not A then @@ -1287,41 +1099,46 @@ function commands.check(arg) table.insert(prob, "AAAA"); end if #prob > 0 then - print(" File transfer proxy "..proxy65_target.." has no "..table.concat(prob, "/").." record. Create one or set 'proxy65_address' to the correct host/IP."); + print(" File transfer proxy "..proxy65_target.." has no "..table.concat(prob, "/") + .." record. Create one or set 'proxy65_address' to the correct host/IP."); end end - for host in target_hosts do + for target_host in target_hosts do local host_ok_v4, host_ok_v6; - local res = dns.lookup(idna.to_ascii(host), "A"); - if res then - for _, record in ipairs(res) do - if external_addresses:contains(record.a) then - some_targets_ok = true; - host_ok_v4 = true; - elseif internal_addresses:contains(record.a) then - host_ok_v4 = true; - some_targets_ok = true; - print(" "..host.." A record points to internal address, external connections might fail"); - else - print(" "..host.." A record points to unknown address "..record.a); - all_targets_ok = false; + do + local res = dns.lookup(idna.to_ascii(target_host), "A"); + if res then + for _, record in ipairs(res) do + if external_addresses:contains(record.a) then + some_targets_ok = true; + host_ok_v4 = true; + elseif internal_addresses:contains(record.a) then + host_ok_v4 = true; + some_targets_ok = true; + print(" "..target_host.." A record points to internal address, external connections might fail"); + else + print(" "..target_host.." A record points to unknown address "..record.a); + all_targets_ok = false; + end end end end - local res = dns.lookup(idna.to_ascii(host), "AAAA"); - if res then - for _, record in ipairs(res) do - if external_addresses:contains(record.aaaa) then - some_targets_ok = true; - host_ok_v6 = true; - elseif internal_addresses:contains(record.aaaa) then - host_ok_v6 = true; - some_targets_ok = true; - print(" "..host.." AAAA record points to internal address, external connections might fail"); - else - print(" "..host.." AAAA record points to unknown address "..record.aaaa); - all_targets_ok = false; + do + local res = dns.lookup(idna.to_ascii(target_host), "AAAA"); + if res then + for _, record in ipairs(res) do + if external_addresses:contains(record.aaaa) then + some_targets_ok = true; + host_ok_v6 = true; + elseif internal_addresses:contains(record.aaaa) then + host_ok_v6 = true; + some_targets_ok = true; + print(" "..target_host.." AAAA record points to internal address, external connections might fail"); + else + print(" "..target_host.." AAAA record points to unknown address "..record.aaaa); + all_targets_ok = false; + end end end end @@ -1334,11 +1151,11 @@ function commands.check(arg) table.insert(bad_protos, "IPv6"); end if #bad_protos > 0 then - print(" Host "..host.." does not seem to resolve to this server ("..table.concat(bad_protos, "/")..")"); + print(" Host "..target_host.." does not seem to resolve to this server ("..table.concat(bad_protos, "/")..")"); end if host_ok_v6 and not v6_supported then - print(" Host "..host.." has AAAA records, but your version of LuaSocket does not support IPv6."); - print(" Please see http://prosody.im/doc/ipv6 for more information."); + print(" Host "..target_host.." has AAAA records, but your version of LuaSocket does not support IPv6."); + print(" Please see https://prosody.im/doc/ipv6 for more information."); end end if not all_targets_ok then @@ -1352,7 +1169,7 @@ function commands.check(arg) end if not problem_hosts:empty() then print(""); - print("For more information about DNS configuration please see http://prosody.im/doc/dns"); + print("For more information about DNS configuration please see https://prosody.im/doc/dns"); print(""); ok = false; end @@ -1383,9 +1200,9 @@ function commands.check(arg) for host in it.filter(skip_bare_jid_hosts, enabled_hosts()) do print("Checking certificate for "..host); -- First, let's find out what certificate this host uses. - local host_ssl_config = config.rawget(host, "ssl") - or config.rawget(host:match("%.(.*)"), "ssl"); - local global_ssl_config = config.rawget("*", "ssl"); + local host_ssl_config = configmanager.rawget(host, "ssl") + or configmanager.rawget(host:match("%.(.*)"), "ssl"); + local global_ssl_config = configmanager.rawget("*", "ssl"); local ok, err, ssl_config = create_context(host, "server", host_ssl_config, global_ssl_config); if not ok then print(" Error: "..err); @@ -1410,7 +1227,7 @@ function commands.check(arg) cert_ok = false else print(" Certificate: "..ssl_config.certificate) - local cert = load_cert(cert_fh:read"*a"); cert_fh = cert_fh:close(); + local cert = load_cert(cert_fh:read"*a"); cert_fh:close(); if not cert:validat(os.time()) then print(" Certificate has expired.") cert_ok = false @@ -1422,13 +1239,13 @@ function commands.check(arg) elseif not cert:validat(os.time() + 86400*31) then print(" Certificate expires within one month.") end - if config.get(host, "component_module") == nil + if configmanager.get(host, "component_module") == nil and not x509_verify_identity(host, "_xmpp-client", cert) then print(" Not valid for client connections to "..host..".") cert_ok = false end - if (not (config.get(host, "anonymous_login") - or config.get(host, "authentication") == "anonymous")) + if (not (configmanager.get(host, "anonymous_login") + or configmanager.get(host, "authentication") == "anonymous")) and not x509_verify_identity(host, "_xmpp-server", cert) then print(" Not valid for server-to-server connections to "..host..".") cert_ok = false @@ -1436,11 +1253,11 @@ function commands.check(arg) end end end - if cert_ok == false then - print("") - print("For more information about certificates please see http://prosody.im/doc/certificates"); - ok = false - end + end + if cert_ok == false then + print("") + print("For more information about certificates please see https://prosody.im/doc/certificates"); + ok = false end print("") end @@ -1454,77 +1271,93 @@ end --------------------- -if command and command:match("^mod_") then -- Is a command in a module - local module_name = command:match("^mod_(.+)"); - local ret, err = modulemanager.load("*", module_name); - if not ret then - show_message("Failed to load module '"..module_name.."': "..err); - os.exit(1); - end +local async = require "util.async"; +local server = require "net.server"; +local watchers = { + error = function (_, err) + error(err); + end; + waiting = function () + server.loop(); + end; +}; +local command_runner = async.runner(function () + if command and command:match("^mod_") then -- Is a command in a module + local module_name = command:match("^mod_(.+)"); + do + local ret, err = modulemanager.load("*", module_name); + if not ret then + show_message("Failed to load module '"..module_name.."': "..err); + os.exit(1); + end + end - table.remove(arg, 1); + table.remove(arg, 1); - local module = modulemanager.get_module("*", module_name); - if not module then - show_message("Failed to load module '"..module_name.."': Unknown error"); - os.exit(1); - end + local module = modulemanager.get_module("*", module_name); + if not module then + show_message("Failed to load module '"..module_name.."': Unknown error"); + os.exit(1); + end - if not modulemanager.module_has_method(module, "command") then - show_message("Fail: mod_"..module_name.." does not support any commands"); - os.exit(1); - end + if not modulemanager.module_has_method(module, "command") then + show_message("Fail: mod_"..module_name.." does not support any commands"); + os.exit(1); + end - local ok, ret = modulemanager.call_module_method(module, "command", arg); - if ok then - if type(ret) == "number" then - os.exit(ret); - elseif type(ret) == "string" then - show_message(ret); + local ok, ret = modulemanager.call_module_method(module, "command", arg); + if ok then + if type(ret) == "number" then + os.exit(ret); + elseif type(ret) == "string" then + show_message(ret); + end + os.exit(0); -- :) + else + show_message("Failed to execute command: "..error_messages[ret]); + os.exit(1); -- :( end - os.exit(0); -- :) - else - show_message("Failed to execute command: "..error_messages[ret]); - os.exit(1); -- :( end -end -if not commands[command] then -- Show help for all commands - function show_usage(usage, desc) - print(" "..usage); - print(" "..desc); - end + if not commands[command] then -- Show help for all commands + function show_usage(usage, desc) + print(" "..usage); + print(" "..desc); + end - print("prosodyctl - Manage a Prosody server"); - print(""); - print("Usage: "..arg[0].." COMMAND [OPTIONS]"); - print(""); - print("Where COMMAND may be one of:\n"); + print("prosodyctl - Manage a Prosody server"); + print(""); + print("Usage: "..arg[0].." COMMAND [OPTIONS]"); + print(""); + print("Where COMMAND may be one of:\n"); - local hidden_commands = require "util.set".new{ "register", "unregister", "addplugin" }; - local commands_order = { "adduser", "passwd", "deluser", "start", "stop", "restart", "reload", "about" }; + local hidden_commands = require "util.set".new{ "register", "unregister", "addplugin" }; + local commands_order = { "adduser", "passwd", "deluser", "start", "stop", "restart", "reload", "about" }; - local done = {}; + local done = {}; - for _, command_name in ipairs(commands_order) do - local command = commands[command_name]; - if command then - command{ "--help" }; - print"" - done[command_name] = true; + for _, command_name in ipairs(commands_order) do + local command_func = commands[command_name]; + if command_func then + command_func{ "--help" }; + print"" + done[command_name] = true; + end end - end - for command_name, command in pairs(commands) do - if not done[command_name] and not hidden_commands:contains(command_name) then - command{ "--help" }; - print"" - done[command_name] = true; + for command_name, command_func in pairs(commands) do + if not done[command_name] and not hidden_commands:contains(command_name) then + command_func{ "--help" }; + print"" + done[command_name] = true; + end end - end - os.exit(0); -end + os.exit(0); + end + + os.exit(commands[command]({ select(2, unpack(arg)) })); +end, watchers); -os.exit(commands[command]({ select(2, unpack(arg)) })); +command_runner:run(true); diff --git a/spec/core_configmanager_spec.lua b/spec/core_configmanager_spec.lua new file mode 100644 index 00000000..b68d2756 --- /dev/null +++ b/spec/core_configmanager_spec.lua @@ -0,0 +1,31 @@ + +local configmanager = require "core.configmanager"; + +describe("core.configmanager", function() + describe("#get()", function() + it("should work", function() + configmanager.set("example.com", "testkey", 123); + assert.are.equal(configmanager.get("example.com", "testkey"), 123, "Retrieving a set key"); + + configmanager.set("*", "testkey1", 321); + assert.are.equal(configmanager.get("*", "testkey1"), 321, "Retrieving a set global key"); + assert.are.equal(configmanager.get("example.com", "testkey1"), 321, "Retrieving a set key of undefined host, of which only a globally set one exists"); + + configmanager.set("example.com", ""); -- Creates example.com host in config + assert.are.equal(configmanager.get("example.com", "testkey1"), 321, "Retrieving a set key, of which only a globally set one exists"); + + assert.are.equal(configmanager.get(), nil, "No parameters to get()"); + assert.are.equal(configmanager.get("undefined host"), nil, "Getting for undefined host"); + assert.are.equal(configmanager.get("undefined host", "undefined key"), nil, "Getting for undefined host & key"); + end); + end); + + describe("#set()", function() + it("should work", function() + assert.are.equal(configmanager.set("*"), false, "Set with no key"); + + assert.are.equal(configmanager.set("*", "set_test", "testkey"), true, "Setting a nil global value"); + assert.are.equal(configmanager.set("*", "set_test", "testkey", 123), true, "Setting a global value"); + end); + end); +end); diff --git a/spec/core_moduleapi_spec.lua b/spec/core_moduleapi_spec.lua new file mode 100644 index 00000000..20431935 --- /dev/null +++ b/spec/core_moduleapi_spec.lua @@ -0,0 +1,76 @@ + +package.loaded["core.configmanager"] = {}; +package.loaded["core.statsmanager"] = {}; +package.loaded["net.server"] = {}; + +local set = require "util.set"; + +_G.prosody = { hosts = {}, core_post_stanza = true }; + +local api = require "core.moduleapi"; + +local module = setmetatable({}, {__index = api}); +local opt = nil; +function module:log() end +function module:get_option(name) + if name == "opt" then + return opt; + else + return nil; + end +end + +function test_option_value(value, returns) + opt = value; + assert(module:get_option_number("opt") == returns.number, "number doesn't match"); + assert(module:get_option_string("opt") == returns.string, "string doesn't match"); + assert(module:get_option_boolean("opt") == returns.boolean, "boolean doesn't match"); + + if type(returns.array) == "table" then + local target_array, returned_array = returns.array, module:get_option_array("opt"); + assert(#target_array == #returned_array, "array length doesn't match"); + for i=1,#target_array do + assert(target_array[i] == returned_array[i], "array item doesn't match"); + end + else + assert(module:get_option_array("opt") == returns.array, "array is returned (not nil)"); + end + + if type(returns.set) == "table" then + local target_items, returned_items = set.new(returns.set), module:get_option_set("opt"); + assert(target_items == returned_items, "set doesn't match"); + else + assert(module:get_option_set("opt") == returns.set, "set is returned (not nil)"); + end +end + +describe("core.moduleapi", function() + describe("#get_option_*()", function() + it("should handle missing options", function() + test_option_value(nil, {}); + end); + + it("should return correctly handle boolean options", function() + test_option_value(true, { boolean = true, string = "true", array = {true}, set = {true} }); + test_option_value(false, { boolean = false, string = "false", array = {false}, set = {false} }); + test_option_value("true", { boolean = true, string = "true", array = {"true"}, set = {"true"} }); + test_option_value("false", { boolean = false, string = "false", array = {"false"}, set = {"false"} }); + test_option_value(1, { boolean = true, string = "1", array = {1}, set = {1}, number = 1 }); + test_option_value(0, { boolean = false, string = "0", array = {0}, set = {0}, number = 0 }); + end); + + it("should return handle strings", function() + test_option_value("hello world", { string = "hello world", array = {"hello world"}, set = {"hello world"} }); + end); + + it("should return handle numbers", function() + test_option_value(1234, { string = "1234", number = 1234, array = {1234}, set = {1234} }); + end); + + it("should return handle arrays", function() + test_option_value({1, 2, 3}, { boolean = true, string = "1", number = 1, array = {1, 2, 3}, set = {1, 2, 3} }); + test_option_value({1, 2, 3, 3, 4}, {boolean = true, string = "1", number = 1, array = {1, 2, 3, 3, 4}, set = {1, 2, 3, 4} }); + test_option_value({0, 1, 2, 3}, { boolean = false, string = "0", number = 0, array = {0, 1, 2, 3}, set = {0, 1, 2, 3} }); + end); + end) +end) diff --git a/tests/json/fail1.json b/spec/json/fail1.json index 6216b865..6216b865 100644 --- a/tests/json/fail1.json +++ b/spec/json/fail1.json diff --git a/tests/json/fail10.json b/spec/json/fail10.json index 5d8c0047..5d8c0047 100644 --- a/tests/json/fail10.json +++ b/spec/json/fail10.json diff --git a/tests/json/fail11.json b/spec/json/fail11.json index 76eb95b4..76eb95b4 100644 --- a/tests/json/fail11.json +++ b/spec/json/fail11.json diff --git a/tests/json/fail12.json b/spec/json/fail12.json index 77580a45..77580a45 100644 --- a/tests/json/fail12.json +++ b/spec/json/fail12.json diff --git a/tests/json/fail13.json b/spec/json/fail13.json index 379406b5..379406b5 100644 --- a/tests/json/fail13.json +++ b/spec/json/fail13.json diff --git a/tests/json/fail14.json b/spec/json/fail14.json index 0ed366b3..0ed366b3 100644 --- a/tests/json/fail14.json +++ b/spec/json/fail14.json diff --git a/tests/json/fail15.json b/spec/json/fail15.json index fc8376b6..fc8376b6 100644 --- a/tests/json/fail15.json +++ b/spec/json/fail15.json diff --git a/tests/json/fail16.json b/spec/json/fail16.json index 3fe21d4b..3fe21d4b 100644 --- a/tests/json/fail16.json +++ b/spec/json/fail16.json diff --git a/tests/json/fail17.json b/spec/json/fail17.json index 62b9214a..62b9214a 100644 --- a/tests/json/fail17.json +++ b/spec/json/fail17.json diff --git a/tests/json/fail18.json b/spec/json/fail18.json index edac9271..edac9271 100644 --- a/tests/json/fail18.json +++ b/spec/json/fail18.json diff --git a/tests/json/fail19.json b/spec/json/fail19.json index 3b9c46fa..3b9c46fa 100644 --- a/tests/json/fail19.json +++ b/spec/json/fail19.json diff --git a/tests/json/fail2.json b/spec/json/fail2.json index 6b7c11e5..6b7c11e5 100644 --- a/tests/json/fail2.json +++ b/spec/json/fail2.json diff --git a/tests/json/fail20.json b/spec/json/fail20.json index 27c1af3e..27c1af3e 100644 --- a/tests/json/fail20.json +++ b/spec/json/fail20.json diff --git a/tests/json/fail21.json b/spec/json/fail21.json index 62474573..62474573 100644 --- a/tests/json/fail21.json +++ b/spec/json/fail21.json diff --git a/tests/json/fail22.json b/spec/json/fail22.json index a7752581..a7752581 100644 --- a/tests/json/fail22.json +++ b/spec/json/fail22.json diff --git a/tests/json/fail23.json b/spec/json/fail23.json index 494add1c..494add1c 100644 --- a/tests/json/fail23.json +++ b/spec/json/fail23.json diff --git a/tests/json/fail24.json b/spec/json/fail24.json index caff239b..caff239b 100644 --- a/tests/json/fail24.json +++ b/spec/json/fail24.json diff --git a/tests/json/fail25.json b/spec/json/fail25.json index 8b7ad23e..8b7ad23e 100644 --- a/tests/json/fail25.json +++ b/spec/json/fail25.json diff --git a/tests/json/fail26.json b/spec/json/fail26.json index 845d26a6..845d26a6 100644 --- a/tests/json/fail26.json +++ b/spec/json/fail26.json diff --git a/tests/json/fail27.json b/spec/json/fail27.json index 6b01a2ca..6b01a2ca 100644 --- a/tests/json/fail27.json +++ b/spec/json/fail27.json diff --git a/tests/json/fail28.json b/spec/json/fail28.json index 621a0101..621a0101 100644 --- a/tests/json/fail28.json +++ b/spec/json/fail28.json diff --git a/tests/json/fail29.json b/spec/json/fail29.json index 47ec421b..47ec421b 100644 --- a/tests/json/fail29.json +++ b/spec/json/fail29.json diff --git a/tests/json/fail3.json b/spec/json/fail3.json index 168c81eb..168c81eb 100644 --- a/tests/json/fail3.json +++ b/spec/json/fail3.json diff --git a/tests/json/fail30.json b/spec/json/fail30.json index 8ab0bc4b..8ab0bc4b 100644 --- a/tests/json/fail30.json +++ b/spec/json/fail30.json diff --git a/tests/json/fail31.json b/spec/json/fail31.json index 1cce602b..1cce602b 100644 --- a/tests/json/fail31.json +++ b/spec/json/fail31.json diff --git a/tests/json/fail32.json b/spec/json/fail32.json index 45cba739..45cba739 100644 --- a/tests/json/fail32.json +++ b/spec/json/fail32.json diff --git a/tests/json/fail33.json b/spec/json/fail33.json index ca5eb19d..ca5eb19d 100644 --- a/tests/json/fail33.json +++ b/spec/json/fail33.json diff --git a/tests/json/fail4.json b/spec/json/fail4.json index 9de168bf..9de168bf 100644 --- a/tests/json/fail4.json +++ b/spec/json/fail4.json diff --git a/tests/json/fail5.json b/spec/json/fail5.json index ddf3ce3d..ddf3ce3d 100644 --- a/tests/json/fail5.json +++ b/spec/json/fail5.json diff --git a/tests/json/fail6.json b/spec/json/fail6.json index ed91580e..ed91580e 100644 --- a/tests/json/fail6.json +++ b/spec/json/fail6.json diff --git a/tests/json/fail7.json b/spec/json/fail7.json index 8a96af3e..8a96af3e 100644 --- a/tests/json/fail7.json +++ b/spec/json/fail7.json diff --git a/tests/json/fail8.json b/spec/json/fail8.json index b28479c6..b28479c6 100644 --- a/tests/json/fail8.json +++ b/spec/json/fail8.json diff --git a/tests/json/fail9.json b/spec/json/fail9.json index 5815574f..5815574f 100644 --- a/tests/json/fail9.json +++ b/spec/json/fail9.json diff --git a/tests/json/pass1.json b/spec/json/pass1.json index 70e26854..70e26854 100644 --- a/tests/json/pass1.json +++ b/spec/json/pass1.json diff --git a/tests/json/pass2.json b/spec/json/pass2.json index d3c63c7a..d3c63c7a 100644 --- a/tests/json/pass2.json +++ b/spec/json/pass2.json diff --git a/tests/json/pass3.json b/spec/json/pass3.json index 4528d51f..4528d51f 100644 --- a/tests/json/pass3.json +++ b/spec/json/pass3.json diff --git a/spec/net_http_parser_spec.lua b/spec/net_http_parser_spec.lua new file mode 100644 index 00000000..6bba087c --- /dev/null +++ b/spec/net_http_parser_spec.lua @@ -0,0 +1,52 @@ +local httpstreams = { [[ +GET / HTTP/1.1 +Host: example.com + +]], [[ +HTTP/1.1 200 OK +Content-Length: 0 + +]], [[ +HTTP/1.1 200 OK +Content-Length: 7 + +Hello +HTTP/1.1 200 OK +Transfer-Encoding: chunked + +1 +H +1 +e +2 +ll +1 +o +0 + + +]] +} + + +local http_parser = require "net.http.parser"; + +describe("net.http.parser", function() + describe("#new()", function() + it("should work", function() + for _, stream in ipairs(httpstreams) do + local success; + local function success_cb(packet) + success = true; + end + stream = stream:gsub("\n", "\r\n"); + local parser = http_parser.new(success_cb, error, stream:sub(1,4) == "HTTP" and "client" or "server") + for chunk in stream:gmatch("..?.?") do + parser:feed(chunk); + end + + assert.is_true(success); + end + end); + end); +end); diff --git a/spec/net_http_server_spec.lua b/spec/net_http_server_spec.lua new file mode 100644 index 00000000..758b619d --- /dev/null +++ b/spec/net_http_server_spec.lua @@ -0,0 +1,13 @@ +describe("net.http.server", function () + package.loaded["net.server"] = {} + local server = require "net.http.server"; + describe("events", function () + it("should work with util.helpers", function () + -- See #1044 + server.add_handler("GET host/foo/*", function () end, 0); + server.add_handler("GET host/foo/bar", function () end, 0); + local helpers = require "util.helpers"; + assert.is.string(helpers.show_events(server._events)); + end); + end); +end); diff --git a/tests/utf8_sequences.txt b/spec/utf8_sequences.txt index 1b967b2e..1b967b2e 100644 --- a/tests/utf8_sequences.txt +++ b/spec/utf8_sequences.txt diff --git a/spec/util_async_spec.lua b/spec/util_async_spec.lua new file mode 100644 index 00000000..d2de8c94 --- /dev/null +++ b/spec/util_async_spec.lua @@ -0,0 +1,616 @@ +local async = require "util.async"; + +describe("util.async", function() + local debug = false; + local print = print; + if debug then + require "util.logger".add_simple_sink(print); + else + print = function () end + end + + local function mock_watchers(event_log) + local function generic_logging_watcher(name) + return function (...) + table.insert(event_log, { name = name, n = select("#", ...)-1, select(2, ...) }); + end; + end; + return setmetatable(mock{ + ready = generic_logging_watcher("ready"); + waiting = generic_logging_watcher("waiting"); + error = generic_logging_watcher("error"); + }, { + __index = function (_, event) + -- Unexpected watcher called + assert(false, "unexpected watcher called: "..event); + end; + }) + end + + local function new(func) + local event_log = {}; + local spy_func = spy.new(func); + return async.runner(spy_func, mock_watchers(event_log)), spy_func, event_log; + end + describe("#runner", function() + it("should work", function() + local r = new(function (item) assert(type(item) == "number") end); + r:run(1); + r:run(2); + end); + + it("should be ready after creation", function () + local r = new(function () end); + assert.equal(r.state, "ready"); + end); + + it("should do nothing if the queue is empty", function () + local did_run; + local r = new(function () did_run = true end); + r:run(); + assert.equal(r.state, "ready"); + assert.is_nil(did_run); + r:run("hello"); + assert.is_true(did_run); + end); + + it("should support queuing work items without running", function () + local did_run; + local r = new(function () did_run = true end); + r:enqueue("hello"); + assert.equal(r.state, "ready"); + assert.is_nil(did_run); + r:run(); + assert.is_true(did_run); + end); + + it("should support queuing multiple work items", function () + local last_item; + local r, s = new(function (item) last_item = item; end); + r:enqueue("hello"); + r:enqueue("there"); + r:enqueue("world"); + assert.equal(r.state, "ready"); + r:run(); + assert.equal(r.state, "ready"); + assert.spy(s).was.called(3); + assert.equal(last_item, "world"); + end); + + it("should support all simple data types", function () + local last_item; + local r, s = new(function (item) last_item = item; end); + local values = { {}, 123, "hello", true, false }; + for i = 1, #values do + r:enqueue(values[i]); + end + assert.equal(r.state, "ready"); + r:run(); + assert.equal(r.state, "ready"); + assert.spy(s).was.called(#values); + for i = 1, #values do + assert.spy(s).was.called_with(values[i]); + end + assert.equal(last_item, values[#values]); + end); + + it("should work with no parameters", function () + local item = "fail"; + local r = async.runner(); + local f = spy.new(function () item = "success"; end); + r:run(f); + assert.spy(f).was.called(); + assert.equal(item, "success"); + end); + + it("supports a default error handler", function () + local item = "fail"; + local r = async.runner(); + local f = spy.new(function () error("test error"); end); + assert.error_matches(function () + r:run(f); + end, "test error"); + assert.spy(f).was.called(); + assert.equal(item, "fail"); + end); + + describe("#errors", function () + describe("should notify", function () + local last_processed_item, last_error; + local r; + r = async.runner(function (item) + if item == "error" then + error({ e = "test error" }); + end + last_processed_item = item; + end, mock{ + ready = function () end; + waiting = function () end; + error = function (runner, err) + assert.equal(r, runner); + last_error = err; + end; + }); + + -- Simple item, no error + r:run("hello"); + assert.equal(r.state, "ready"); + assert.equal(last_processed_item, "hello"); + assert.spy(r.watchers.ready).was_not.called(); + assert.spy(r.watchers.error).was_not.called(); + + -- Trigger an error inside the runner + assert.equal(last_error, nil); + r:run("error"); + test("the correct watcher functions", function () + -- Only the error watcher should have been called + assert.spy(r.watchers.ready).was_not.called(); + assert.spy(r.watchers.waiting).was_not.called(); + assert.spy(r.watchers.error).was.called(1); + end); + test("with the correct error", function () + -- The error watcher state should be correct, to + -- demonstrate the error was passed correctly + assert.is_table(last_error); + assert.equal(last_error.e, "test error"); + last_error = nil; + end); + assert.equal(r.state, "ready"); + assert.equal(last_processed_item, "hello"); + end); + + do + local last_processed_item, last_error; + local r; + local wait, done; + r = async.runner(function (item) + if item == "error" then + error({ e = "test error" }); + elseif item == "wait" then + wait, done = async.waiter(); + wait(); + error({ e = "post wait error" }); + end + last_processed_item = item; + end, mock({ + ready = function () end; + waiting = function () end; + error = function (runner, err) + assert.equal(r, runner); + last_error = err; + end; + })); + + randomize(false); --luacheck: ignore 113/randomize + + it("should not be fatal to the runner", function () + r:run("world"); + assert.equal(r.state, "ready"); + assert.spy(r.watchers.ready).was_not.called(); + assert.equal(last_processed_item, "world"); + end); + it("should work despite a #waiter", function () + -- This test covers an important case where a runner + -- throws an error while being executed outside of the + -- main loop. This happens when it was blocked ('waiting'), + -- and then released (via a call to done()). + last_error = nil; + r:run("wait"); + assert.equal(r.state, "waiting"); + assert.spy(r.watchers.waiting).was.called(1); + done(); + -- At this point an error happens (state goes error->ready) + assert.equal(r.state, "ready"); + assert.spy(r.watchers.error).was.called(1); + assert.spy(r.watchers.ready).was.called(1); + assert.is_table(last_error); + assert.equal(last_error.e, "post wait error"); + last_error = nil; + r:run("hello again"); + assert.spy(r.watchers.ready).was.called(1); + assert.spy(r.watchers.waiting).was.called(1); + assert.spy(r.watchers.error).was.called(1); + assert.equal(r.state, "ready"); + assert.equal(last_processed_item, "hello again"); + end); + end + + it("should continue to process work items", function () + local last_item; + local runner, runner_func = new(function (item) + if item == "error" then + error("test error"); + end + last_item = item; + end); + runner:enqueue("one"); + runner:enqueue("error"); + runner:enqueue("two"); + runner:run(); + assert.equal(runner.state, "ready"); + assert.spy(runner_func).was.called(3); + assert.spy(runner.watchers.error).was.called(1); + assert.spy(runner.watchers.ready).was.called(0); + assert.spy(runner.watchers.waiting).was.called(0); + assert.equal(last_item, "two"); + end); + + it("should continue to process work items during resume", function () + local wait, done, last_item; + local runner, runner_func = new(function (item) + if item == "wait-error" then + wait, done = async.waiter(); + wait(); + error("test error"); + end + last_item = item; + end); + runner:enqueue("one"); + runner:enqueue("wait-error"); + runner:enqueue("two"); + runner:run(); + done(); + assert.equal(runner.state, "ready"); + assert.spy(runner_func).was.called(3); + assert.spy(runner.watchers.error).was.called(1); + assert.spy(runner.watchers.waiting).was.called(1); + assert.spy(runner.watchers.ready).was.called(1); + assert.equal(last_item, "two"); + end); + end); + end); + describe("#waiter", function() + it("should error outside of async context", function () + assert.has_error(function () + async.waiter(); + end); + end); + it("should work", function () + local wait, done; + + local r = new(function (item) + assert(type(item) == "number") + if item == 3 then + wait, done = async.waiter(); + wait(); + end + end); + + r:run(1); + assert(r.state == "ready"); + r:run(2); + assert(r.state == "ready"); + r:run(3); + assert(r.state == "waiting"); + done(); + assert(r.state == "ready"); + --for k, v in ipairs(l) do print(k,v) end + end); + + it("should work", function () + -------------------- + local wait, done; + local last_item = 0; + local r = new(function (item) + assert(type(item) == "number") + assert(item == last_item + 1); + last_item = item; + if item == 3 then + wait, done = async.waiter(); + wait(); + end + end); + + r:run(1); + assert(r.state == "ready"); + r:run(2); + assert(r.state == "ready"); + r:run(3); + assert(r.state == "waiting"); + r:run(4); + assert(r.state == "waiting"); + done(); + assert(r.state == "ready"); + --for k, v in ipairs(l) do print(k,v) end + end); + it("should work", function () + -------------------- + local wait, done; + local last_item = 0; + local r = new(function (item) + assert(type(item) == "number") + assert((item == last_item + 1) or item == 3); + last_item = item; + if item == 3 then + wait, done = async.waiter(); + wait(); + end + end); + + r:run(1); + assert(r.state == "ready"); + r:run(2); + assert(r.state == "ready"); + + r:run(3); + assert(r.state == "waiting"); + r:run(3); + assert(r.state == "waiting"); + r:run(3); + assert(r.state == "waiting"); + r:run(4); + assert(r.state == "waiting"); + + for i = 1, 3 do + done(); + if i < 3 then + assert(r.state == "waiting"); + end + end + + assert(r.state == "ready"); + --for k, v in ipairs(l) do print(k,v) end + end); + it("should work", function () + -------------------- + local wait, done; + local last_item = 0; + local r = new(function (item) + assert(type(item) == "number") + assert((item == last_item + 1) or item == 3); + last_item = item; + if item == 3 then + wait, done = async.waiter(); + wait(); + end + end); + + r:run(1); + assert(r.state == "ready"); + r:run(2); + assert(r.state == "ready"); + + r:run(3); + assert(r.state == "waiting"); + r:run(3); + assert(r.state == "waiting"); + + for i = 1, 2 do + done(); + if i < 2 then + assert(r.state == "waiting"); + end + end + + assert(r.state == "ready"); + r:run(4); + assert(r.state == "ready"); + + assert(r.state == "ready"); + --for k, v in ipairs(l) do print(k,v) end + end); + it("should work with multiple runners in parallel", function () + -- Now with multiple runners + -------------------- + local wait1, done1; + local last_item1 = 0; + local r1 = new(function (item) + assert(type(item) == "number") + assert((item == last_item1 + 1) or item == 3); + last_item1 = item; + if item == 3 then + wait1, done1 = async.waiter(); + wait1(); + end + end, "r1"); + + local wait2, done2; + local last_item2 = 0; + local r2 = new(function (item) + assert(type(item) == "number") + assert((item == last_item2 + 1) or item == 3); + last_item2 = item; + if item == 3 then + wait2, done2 = async.waiter(); + wait2(); + end + end, "r2"); + + r1:run(1); + assert(r1.state == "ready"); + r1:run(2); + assert(r1.state == "ready"); + + r1:run(3); + assert(r1.state == "waiting"); + r1:run(3); + assert(r1.state == "waiting"); + + r2:run(1); + assert(r1.state == "waiting"); + assert(r2.state == "ready"); + + r2:run(2); + assert(r1.state == "waiting"); + assert(r2.state == "ready"); + + r2:run(3); + assert(r1.state == "waiting"); + assert(r2.state == "waiting"); + done2(); + + r2:run(3); + assert(r1.state == "waiting"); + assert(r2.state == "waiting"); + done2(); + + r2:run(4); + assert(r1.state == "waiting"); + assert(r2.state == "ready"); + + for i = 1, 2 do + done1(); + if i < 2 then + assert(r1.state == "waiting"); + end + end + + assert(r1.state == "ready"); + r1:run(4); + assert(r1.state == "ready"); + + assert(r1.state == "ready"); + --for k, v in ipairs(l1) do print(k,v) end + end); + it("should work work with multiple runners in parallel", function () + -------------------- + local wait1, done1; + local last_item1 = 0; + local r1 = new(function (item) + print("r1 processing ", item); + assert(type(item) == "number") + assert((item == last_item1 + 1) or item == 3); + last_item1 = item; + if item == 3 then + wait1, done1 = async.waiter(); + wait1(); + end + end, "r1"); + + local wait2, done2; + local last_item2 = 0; + local r2 = new(function (item) + print("r2 processing ", item); + assert.is_number(item); + assert((item == last_item2 + 1) or item == 3); + last_item2 = item; + if item == 3 then + wait2, done2 = async.waiter(); + wait2(); + end + end, "r2"); + + r1:run(1); + assert.equal(r1.state, "ready"); + r1:run(2); + assert.equal(r1.state, "ready"); + + r1:run(5); + assert.equal(r1.state, "ready"); + + r1:run(3); + assert.equal(r1.state, "waiting"); + r1:run(5); -- Will error, when we get to it + assert.equal(r1.state, "waiting"); + done1(); + assert.equal(r1.state, "ready"); + r1:run(3); + assert.equal(r1.state, "waiting"); + + r2:run(1); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "ready"); + + r2:run(2); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "ready"); + + r2:run(3); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "waiting"); + + done2(); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "ready"); + + r2:run(3); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "waiting"); + + done2(); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "ready"); + + r2:run(4); + assert.equal(r1.state, "waiting"); + assert.equal(r2.state, "ready"); + + done1(); + + assert.equal(r1.state, "ready"); + r1:run(4); + assert.equal(r1.state, "ready"); + + assert.equal(r1.state, "ready"); + end); + + it("should support multiple done() calls", function () + local processed_item; + local wait, done; + local r, rf = new(function (item) + wait, done = async.waiter(4); + wait(); + processed_item = item; + end); + r:run("test"); + for _ = 1, 3 do + done(); + assert.equal(r.state, "waiting"); + assert.is_nil(processed_item); + end + done(); + assert.equal(r.state, "ready"); + assert.equal(processed_item, "test"); + assert.spy(r.watchers.error).was_not.called(); + end); + + it("should not allow done() to be called more than specified", function () + local processed_item; + local wait, done; + local r, rf = new(function (item) + wait, done = async.waiter(4); + wait(); + processed_item = item; + end); + r:run("test"); + for _ = 1, 4 do + done(); + end + assert.has_error(done); + assert.equal(r.state, "ready"); + assert.equal(processed_item, "test"); + assert.spy(r.watchers.error).was_not.called(); + end); + + it("should allow done() to be called before wait()", function () + local processed_item; + local r, rf = new(function (item) + local wait, done = async.waiter(); + done(); + wait(); + processed_item = item; + end); + r:run("test"); + assert.equal(processed_item, "test"); + assert.equal(r.state, "ready"); + -- Since the observable state did not change, + -- the watchers should not have been called + assert.spy(r.watchers.waiting).was_not.called(); + assert.spy(r.watchers.ready).was_not.called(); + end); + end); + + describe("#ready()", function () + it("should return false outside an async context", function () + assert.falsy(async.ready()); + end); + it("should return true inside an async context", function () + local r = new(function () + assert.truthy(async.ready()); + end); + r:run(true); + assert.spy(r.func).was.called(); + assert.spy(r.watchers.error).was_not.called(); + end); + end); +end); diff --git a/spec/util_cache_spec.lua b/spec/util_cache_spec.lua new file mode 100644 index 00000000..9c7d75fe --- /dev/null +++ b/spec/util_cache_spec.lua @@ -0,0 +1,316 @@ + +local cache = require "util.cache"; + +describe("util.cache", function() + describe("#new()", function() + it("should work", function() + + local c = cache.new(5); + + local function expect_kv(key, value, actual_key, actual_value) + assert.are.equal(key, actual_key, "key incorrect"); + assert.are.equal(value, actual_value, "value incorrect"); + end + + expect_kv(nil, nil, c:head()); + expect_kv(nil, nil, c:tail()); + + assert.are.equal(c:count(), 0); + + c:set("one", 1) + assert.are.equal(c:count(), 1); + expect_kv("one", 1, c:head()); + expect_kv("one", 1, c:tail()); + + c:set("two", 2) + expect_kv("two", 2, c:head()); + expect_kv("one", 1, c:tail()); + + c:set("three", 3) + expect_kv("three", 3, c:head()); + expect_kv("one", 1, c:tail()); + + c:set("four", 4) + c:set("five", 5); + assert.are.equal(c:count(), 5); + expect_kv("five", 5, c:head()); + expect_kv("one", 1, c:tail()); + + c:set("foo", nil); + assert.are.equal(c:count(), 5); + expect_kv("five", 5, c:head()); + expect_kv("one", 1, c:tail()); + + assert.are.equal(c:get("one"), 1); + expect_kv("five", 5, c:head()); + expect_kv("one", 1, c:tail()); + + assert.are.equal(c:get("two"), 2); + assert.are.equal(c:get("three"), 3); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), 5); + + assert.are.equal(c:get("foo"), nil); + assert.are.equal(c:get("bar"), nil); + + c:set("six", 6); + assert.are.equal(c:count(), 5); + expect_kv("six", 6, c:head()); + expect_kv("two", 2, c:tail()); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), 2); + assert.are.equal(c:get("three"), 3); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), 5); + assert.are.equal(c:get("six"), 6); + + c:set("three", nil); + assert.are.equal(c:count(), 4); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), 2); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), 5); + assert.are.equal(c:get("six"), 6); + + c:set("seven", 7); + assert.are.equal(c:count(), 5); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), 2); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), 5); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + + c:set("eight", 8); + assert.are.equal(c:count(), 5); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), nil); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), 5); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + assert.are.equal(c:get("eight"), 8); + + c:set("four", 4); + assert.are.equal(c:count(), 5); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), nil); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), 5); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + assert.are.equal(c:get("eight"), 8); + + c:set("nine", 9); + assert.are.equal(c:count(), 5); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), nil); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), 4); + assert.are.equal(c:get("five"), nil); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + assert.are.equal(c:get("eight"), 8); + assert.are.equal(c:get("nine"), 9); + + do + local keys = { "nine", "four", "eight", "seven", "six" }; + local values = { 9, 4, 8, 7, 6 }; + local i = 0; + for k, v in c:items() do + i = i + 1; + assert.are.equal(k, keys[i]); + assert.are.equal(v, values[i]); + end + assert.are.equal(i, 5); + + c:set("four", "2+2"); + assert.are.equal(c:count(), 5); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), nil); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), "2+2"); + assert.are.equal(c:get("five"), nil); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + assert.are.equal(c:get("eight"), 8); + assert.are.equal(c:get("nine"), 9); + end + + do + local keys = { "four", "nine", "eight", "seven", "six" }; + local values = { "2+2", 9, 8, 7, 6 }; + local i = 0; + for k, v in c:items() do + i = i + 1; + assert.are.equal(k, keys[i]); + assert.are.equal(v, values[i]); + end + assert.are.equal(i, 5); + + c:set("foo", nil); + assert.are.equal(c:count(), 5); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), nil); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), "2+2"); + assert.are.equal(c:get("five"), nil); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + assert.are.equal(c:get("eight"), 8); + assert.are.equal(c:get("nine"), 9); + end + + do + local keys = { "four", "nine", "eight", "seven", "six" }; + local values = { "2+2", 9, 8, 7, 6 }; + local i = 0; + for k, v in c:items() do + i = i + 1; + assert.are.equal(k, keys[i]); + assert.are.equal(v, values[i]); + end + assert.are.equal(i, 5); + + c:set("four", nil); + + assert.are.equal(c:get("one"), nil); + assert.are.equal(c:get("two"), nil); + assert.are.equal(c:get("three"), nil); + assert.are.equal(c:get("four"), nil); + assert.are.equal(c:get("five"), nil); + assert.are.equal(c:get("six"), 6); + assert.are.equal(c:get("seven"), 7); + assert.are.equal(c:get("eight"), 8); + assert.are.equal(c:get("nine"), 9); + end + + do + local keys = { "nine", "eight", "seven", "six" }; + local values = { 9, 8, 7, 6 }; + local i = 0; + for k, v in c:items() do + i = i + 1; + assert.are.equal(k, keys[i]); + assert.are.equal(v, values[i]); + end + assert.are.equal(i, 4); + end + + do + local evicted_key, evicted_value; + local c2 = cache.new(3, function (_key, _value) + evicted_key, evicted_value = _key, _value; + end); + local function set(k, v, should_evict_key, should_evict_value) + evicted_key, evicted_value = nil, nil; + c2:set(k, v); + assert.are.equal(evicted_key, should_evict_key); + assert.are.equal(evicted_value, should_evict_value); + end + set("a", 1) + set("a", 1) + set("a", 1) + set("a", 1) + set("a", 1) + + set("b", 2) + set("c", 3) + set("b", 2) + set("d", 4, "a", 1) + set("e", 5, "c", 3) + end + + do + local evicted_key, evicted_value; + local c3 = cache.new(1, function (_key, _value) + evicted_key, evicted_value = _key, _value; + if _key == "a" then + -- Sanity check for what we're evicting + assert.are.equal(_key, "a"); + assert.are.equal(_value, 1); + -- We're going to block eviction of this key/value, so set to nil... + evicted_key, evicted_value = nil, nil; + -- Returning false to block eviction + return false + end + end); + local function set(k, v, should_evict_key, should_evict_value) + evicted_key, evicted_value = nil, nil; + local ret = c3:set(k, v); + assert.are.equal(evicted_key, should_evict_key); + assert.are.equal(evicted_value, should_evict_value); + return ret; + end + set("a", 1) + set("a", 1) + set("a", 1) + set("a", 1) + set("a", 1) + + -- Our on_evict prevents "a" from being evicted, causing this to fail... + assert.are.equal(set("b", 2), false, "Failed to prevent eviction, or signal result"); + + expect_kv("a", 1, c3:head()); + expect_kv("a", 1, c3:tail()); + + -- Check the final state is what we expect + assert.are.equal(c3:get("a"), 1); + assert.are.equal(c3:get("b"), nil); + assert.are.equal(c3:count(), 1); + end + + + local c4 = cache.new(3, false); + + assert.are.equal(c4:set("a", 1), true); + assert.are.equal(c4:set("a", 1), true); + assert.are.equal(c4:set("a", 1), true); + assert.are.equal(c4:set("a", 1), true); + assert.are.equal(c4:set("b", 2), true); + assert.are.equal(c4:set("c", 3), true); + assert.are.equal(c4:set("d", 4), false); + assert.are.equal(c4:set("d", 4), false); + assert.are.equal(c4:set("d", 4), false); + + expect_kv("c", 3, c4:head()); + expect_kv("a", 1, c4:tail()); + + local c5 = cache.new(3, function (k, v) + if k == "a" then + return nil; + elseif k == "b" then + return true; + end + return false; + end); + + assert.are.equal(c5:set("a", 1), true); + assert.are.equal(c5:set("a", 1), true); + assert.are.equal(c5:set("a", 1), true); + assert.are.equal(c5:set("a", 1), true); + assert.are.equal(c5:set("b", 2), true); + assert.are.equal(c5:set("c", 3), true); + assert.are.equal(c5:set("d", 4), true); -- "a" evicted (cb returned nil) + assert.are.equal(c5:set("d", 4), true); -- nop + assert.are.equal(c5:set("d", 4), true); -- nop + assert.are.equal(c5:set("e", 5), true); -- "b" evicted (cb returned true) + assert.are.equal(c5:set("f", 6), false); -- "c" won't evict (cb returned false) + + expect_kv("e", 5, c5:head()); + expect_kv("c", 3, c5:tail()); + end); + end); +end); diff --git a/spec/util_dataforms_spec.lua b/spec/util_dataforms_spec.lua new file mode 100644 index 00000000..56751041 --- /dev/null +++ b/spec/util_dataforms_spec.lua @@ -0,0 +1,314 @@ +local dataforms = require "util.dataforms"; +local st = require "util.stanza"; +local jid = require "util.jid"; +local iter = require "util.iterators"; + +describe("util.dataforms", function () + local some_form, xform; + setup(function () + some_form = dataforms.new({ + title = "form-title", + instructions = "form-instructions", + { + type = "hidden", + name = "FORM_TYPE", + value = "xmpp:prosody.im/spec/util.dataforms#1", + }; + { + type = "boolean", + label = "boolean-label", + name = "boolean-field", + value = true, + }, + { + type = "fixed", + label = "fixed-label", + name = "fixed-field", + value = "fixed-value", + }, + { + type = "hidden", + label = "hidden-label", + name = "hidden-field", + value = "hidden-value", + }, + { + type = "jid-multi", + label = "jid-multi-label", + name = "jid-multi-field", + value = { + "jid@multi/value#1", + "jid@multi/value#2", + }, + }, + { + type = "jid-single", + label = "jid-single-label", + name = "jid-single-field", + value = "jid@single/value", + }, + { + type = "list-multi", + label = "list-multi-label", + name = "list-multi-field", + value = { + "list-multi-option-value#1", + "list-multi-option-value#3", + }, + options = { + { + label = "list-multi-option-label#1", + value = "list-multi-option-value#1", + default = true, + }, + { + label = "list-multi-option-label#2", + value = "list-multi-option-value#2", + default = false, + }, + { + label = "list-multi-option-label#3", + value = "list-multi-option-value#3", + default = true, + }, + } + }, + { + type = "list-single", + label = "list-single-label", + name = "list-single-field", + value = "list-single-value", + options = { + "list-single-value", + "list-single-value#2", + "list-single-value#3", + } + }, + { + type = "text-multi", + label = "text-multi-label", + name = "text-multi-field", + value = "text\nmulti\nvalue", + }, + { + type = "text-private", + label = "text-private-label", + name = "text-private-field", + value = "text-private-value", + }, + { + type = "text-single", + label = "text-single-label", + name = "text-single-field", + value = "text-single-value", + }, + }); + xform = some_form:form(); + end); + + it("works", function () + assert.truthy(xform); + assert.truthy(st.is_stanza(xform)); + assert.equal("x", xform.name); + assert.equal("jabber:x:data", xform.attr.xmlns); + assert.equal("FORM_TYPE", xform:find("field@var")); + assert.equal("xmpp:prosody.im/spec/util.dataforms#1", xform:find("field/value#")); + local allowed_direct_children = { + title = true, + instructions = true, + field = true, + } + for tag in xform:childtags() do + assert.truthy(allowed_direct_children[tag.name], "unknown direct child"); + end + end); + + it("produced boolean field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "boolean-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("boolean-field", f.attr.var); + assert.equal("boolean", f.attr.type); + assert.equal("boolean-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + local val = f:get_child_text("value"); + assert.truthy(val == "true" or val == "1"); + end); + + it("produced fixed field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "fixed-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("fixed-field", f.attr.var); + assert.equal("fixed", f.attr.type); + assert.equal("fixed-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + assert.equal("fixed-value", f:get_child_text("value")); + end); + + it("produced hidden field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "hidden-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("hidden-field", f.attr.var); + assert.equal("hidden", f.attr.type); + assert.equal("hidden-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + assert.equal("hidden-value", f:get_child_text("value")); + end); + + it("produced jid-multi field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "jid-multi-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("jid-multi-field", f.attr.var); + assert.equal("jid-multi", f.attr.type); + assert.equal("jid-multi-label", f.attr.label); + assert.equal(2, iter.count(f:childtags("value"))); + + local i = 0; + for value in f:childtags("value") do + i = i + 1; + assert.equal(("jid@multi/value#%d"):format(i), value:get_text()); + end + end); + + it("produced jid-single field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "jid-single-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("jid-single-field", f.attr.var); + assert.equal("jid-single", f.attr.type); + assert.equal("jid-single-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + assert.equal("jid@single/value", f:get_child_text("value")); + assert.truthy(jid.prep(f:get_child_text("value"))); + end); + + it("produced list-multi field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "list-multi-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("list-multi-field", f.attr.var); + assert.equal("list-multi", f.attr.type); + assert.equal("list-multi-label", f.attr.label); + assert.equal(2, iter.count(f:childtags("value"))); + assert.equal("list-multi-option-value#1", f:get_child_text("value")); + assert.equal(3, iter.count(f:childtags("option"))); + end); + + it("produced list-single field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "list-single-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("list-single-field", f.attr.var); + assert.equal("list-single", f.attr.type); + assert.equal("list-single-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + assert.equal("list-single-value", f:get_child_text("value")); + assert.equal(3, iter.count(f:childtags("option"))); + end); + + it("produced text-multi field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "text-multi-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("text-multi-field", f.attr.var); + assert.equal("text-multi", f.attr.type); + assert.equal("text-multi-label", f.attr.label); + assert.equal(3, iter.count(f:childtags("value"))); + end); + + it("produced text-private field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "text-private-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("text-private-field", f.attr.var); + assert.equal("text-private", f.attr.type); + assert.equal("text-private-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + assert.equal("text-private-value", f:get_child_text("value")); + end); + + it("produced text-single field correctly", function () + local f; + for field in xform:childtags("field") do + if field.attr.var == "text-single-field" then + f = field; + break; + end + end + + assert.truthy(st.is_stanza(f)); + assert.equal("text-single-field", f.attr.var); + assert.equal("text-single", f.attr.type); + assert.equal("text-single-label", f.attr.label); + assert.equal(1, iter.count(f:childtags("value"))); + assert.equal("text-single-value", f:get_child_text("value")); + end); + + describe("get_type()", function () + it("identifes dataforms", function () + assert.equal(nil, dataforms.get_type(nil)); + assert.equal(nil, dataforms.get_type("")); + assert.equal(nil, dataforms.get_type({})); + assert.equal(nil, dataforms.get_type(st.stanza("no-a-form"))); + assert.equal("xmpp:prosody.im/spec/util.dataforms#1", dataforms.get_type(xform)); + end); + end); +end); + diff --git a/spec/util_datetime_spec.lua b/spec/util_datetime_spec.lua new file mode 100644 index 00000000..497ab7d3 --- /dev/null +++ b/spec/util_datetime_spec.lua @@ -0,0 +1,76 @@ +local util_datetime = require "util.datetime"; + +describe("util.datetime", function () + it("should have been loaded", function () + assert.is_table(util_datetime); + end); + describe("#date", function () + local date = util_datetime.date; + it("should exist", function () + assert.is_function(date); + end); + it("should return a string", function () + assert.is_string(date()); + end); + it("should look like a date", function () + assert.truthy(string.find(date(), "^%d%d%d%d%-%d%d%-%d%d$")); + end); + it("should work", function () + assert.equals(date(1136239445), "2006-01-02"); + end); + end); + describe("#time", function () + local time = util_datetime.time; + it("should exist", function () + assert.is_function(time); + end); + it("should return a string", function () + assert.is_string(time()); + end); + it("should look like a timestamp", function () + -- Note: Sub-second precision and timezones are ignored + assert.truthy(string.find(time(), "^%d%d:%d%d:%d%d")); + end); + it("should work", function () + assert.equals(time(1136239445), "22:04:05"); + end); + end); + describe("#datetime", function () + local datetime = util_datetime.datetime; + it("should exist", function () + assert.is_function(datetime); + end); + it("should return a string", function () + assert.is_string(datetime()); + end); + it("should look like a timestamp", function () + -- Note: Sub-second precision and timezones are ignored + assert.truthy(string.find(datetime(), "^%d%d%d%d%-%d%d%-%d%dT%d%d:%d%d:%d%d")); + end); + it("should work", function () + assert.equals(datetime(1136239445), "2006-01-02T22:04:05Z"); + end); + end); + describe("#legacy", function () + local legacy = util_datetime.legacy; + it("should exist", function () + assert.is_function(legacy); + end); + end); + describe("#parse", function () + local parse = util_datetime.parse; + it("should exist", function () + assert.is_function(parse); + end); + it("should work", function () + -- Timestamp used by Go + assert.equals(parse("2017-11-19T17:58:13Z"), 1511114293); + assert.equals(parse("2017-11-19T18:58:50+0100"), 1511114330); + assert.equals(parse("2006-01-02T15:04:05-0700"), 1136239445); + end); + it("should handle timezones", function () + -- https://xmpp.org/extensions/xep-0082.html#example-2 and 3 + assert.equals(parse("1969-07-21T02:56:15Z"), parse("1969-07-20T21:56:15-05:00")); + end); + end); +end); diff --git a/spec/util_encodings_spec.lua b/spec/util_encodings_spec.lua new file mode 100644 index 00000000..0f4fc2b7 --- /dev/null +++ b/spec/util_encodings_spec.lua @@ -0,0 +1,41 @@ + +local encodings = require "util.encodings"; +local utf8 = assert(encodings.utf8, "no encodings.utf8 module"); + +describe("util.encodings", function () + describe("#encode()", function() + it("should work", function () + assert.is.equal(encodings.base64.encode(""), ""); + assert.is.equal(encodings.base64.encode('coucou'), "Y291Y291"); + assert.is.equal(encodings.base64.encode("\0\0\0"), "AAAA"); + assert.is.equal(encodings.base64.encode("\255\255\255"), "////"); + end); + end); + describe("#decode()", function() + it("should work", function () + assert.is.equal(encodings.base64.decode(""), ""); + assert.is.equal(encodings.base64.decode("="), ""); + assert.is.equal(encodings.base64.decode('Y291Y291'), "coucou"); + assert.is.equal(encodings.base64.decode("AAAA"), "\0\0\0"); + assert.is.equal(encodings.base64.decode("////"), "\255\255\255"); + end); + end); +end); +describe("util.encodings.utf8", function() + describe("#valid()", function() + it("should work", function() + + for line in io.lines("spec/utf8_sequences.txt") do + local data = line:match(":%s*([^#]+)"):gsub("%s+", ""):gsub("..", function (c) return string.char(tonumber(c, 16)); end) + local expect = line:match("(%S+):"); + + assert(expect == "pass" or expect == "fail", "unknown expectation: "..line:match("^[^:]+")); + + local valid = utf8.valid(data); + assert.is.equal(valid, utf8.valid(data.." ")); + assert.is.equal(valid, expect == "pass", line); + end + + end); + end); +end); diff --git a/spec/util_events_spec.lua b/spec/util_events_spec.lua new file mode 100644 index 00000000..fee60f8f --- /dev/null +++ b/spec/util_events_spec.lua @@ -0,0 +1,212 @@ +local events = require "util.events"; + +describe("util.events", function () + it("should export a new() function", function () + assert.is_function(events.new); + end); + describe("new()", function () + it("should return return a new events object", function () + local e = events.new(); + assert.is_function(e.add_handler); + assert.is_function(e.remove_handler); + end); + end); + + local e, h; + + + describe("API", function () + before_each(function () + e = events.new(); + h = spy.new(function () end); + end); + + it("should call handlers when an event is fired", function () + e.add_handler("myevent", h); + e.fire_event("myevent"); + assert.spy(h).was_called(); + end); + + it("should not call handlers when a different event is fired", function () + e.add_handler("myevent", h); + e.fire_event("notmyevent"); + assert.spy(h).was_not_called(); + end); + + it("should pass the data argument to handlers", function () + e.add_handler("myevent", h); + e.fire_event("myevent", "mydata"); + assert.spy(h).was_called_with("mydata"); + end); + + it("should support non-string events", function () + local myevent = {}; + e.add_handler(myevent, h); + e.fire_event(myevent, "mydata"); + assert.spy(h).was_called_with("mydata"); + end); + + it("should call handlers in priority order", function () + local data = {}; + e.add_handler("myevent", function () table.insert(data, "h1"); end, 5); + e.add_handler("myevent", function () table.insert(data, "h2"); end, 3); + e.add_handler("myevent", function () table.insert(data, "h3"); end); + e.fire_event("myevent", "mydata"); + assert.same(data, { "h1", "h2", "h3" }); + end); + + it("should support non-integer priority values", function () + local data = {}; + e.add_handler("myevent", function () table.insert(data, "h1"); end, 1); + e.add_handler("myevent", function () table.insert(data, "h2"); end, 0.5); + e.add_handler("myevent", function () table.insert(data, "h3"); end, 0.25); + e.fire_event("myevent", "mydata"); + assert.same(data, { "h1", "h2", "h3" }); + end); + + it("should support negative priority values", function () + local data = {}; + e.add_handler("myevent", function () table.insert(data, "h1"); end, 1); + e.add_handler("myevent", function () table.insert(data, "h2"); end, 0); + e.add_handler("myevent", function () table.insert(data, "h3"); end, -1); + e.fire_event("myevent", "mydata"); + assert.same(data, { "h1", "h2", "h3" }); + end); + + it("should support removing handlers", function () + e.add_handler("myevent", h); + e.fire_event("myevent"); + e.remove_handler("myevent", h); + e.fire_event("myevent"); + assert.spy(h).was_called(1); + end); + + it("should support adding multiple handlers at the same time", function () + local ht = { + myevent1 = spy.new(function () end); + myevent2 = spy.new(function () end); + myevent3 = spy.new(function () end); + }; + e.add_handlers(ht); + e.fire_event("myevent1"); + e.fire_event("myevent2"); + assert.spy(ht.myevent1).was_called(); + assert.spy(ht.myevent2).was_called(); + assert.spy(ht.myevent3).was_not_called(); + end); + + it("should support removing multiple handlers at the same time", function () + local ht = { + myevent1 = spy.new(function () end); + myevent2 = spy.new(function () end); + myevent3 = spy.new(function () end); + }; + e.add_handlers(ht); + e.remove_handlers(ht); + e.fire_event("myevent1"); + e.fire_event("myevent2"); + assert.spy(ht.myevent1).was_not_called(); + assert.spy(ht.myevent2).was_not_called(); + assert.spy(ht.myevent3).was_not_called(); + end); + + pending("should support adding handlers within an event handler") + pending("should support removing handlers within an event handler") + + it("should support getting the current handlers for an event", function () + e.add_handler("myevent", h); + local handlers = e.get_handlers("myevent"); + assert.equal(h, handlers[1]); + end); + + describe("wrappers", function () + local w + before_each(function () + w = spy.new(function (handlers, event_name, event_data) + assert.is_function(handlers); + assert.equal("myevent", event_name) + assert.equal("abc", event_data); + return handlers(event_name, event_data); + end); + end); + + it("should get called", function () + e.add_wrapper("myevent", w); + e.add_handler("myevent", h); + e.fire_event("myevent", "abc"); + assert.spy(w).was_called(1); + assert.spy(h).was_called(1); + end); + + it("should be removable", function () + e.add_wrapper("myevent", w); + e.add_handler("myevent", h); + e.fire_event("myevent", "abc"); + e.remove_wrapper("myevent", w); + e.fire_event("myevent", "abc"); + assert.spy(w).was_called(1); + assert.spy(h).was_called(2); + end); + + it("should allow multiple wrappers", function () + local w2 = spy.new(function (handlers, event_name, event_data) + return handlers(event_name, event_data); + end); + e.add_wrapper("myevent", w); + e.add_handler("myevent", h); + e.add_wrapper("myevent", w2); + e.fire_event("myevent", "abc"); + e.remove_wrapper("myevent", w); + e.fire_event("myevent", "abc"); + assert.spy(w).was_called(1); + assert.spy(w2).was_called(2); + assert.spy(h).was_called(2); + end); + + it("should support a mix of global and event wrappers", function () + local w2 = spy.new(function (handlers, event_name, event_data) + return handlers(event_name, event_data); + end); + e.add_wrapper(false, w); + e.add_handler("myevent", h); + e.add_wrapper("myevent", w2); + e.fire_event("myevent", "abc"); + e.remove_wrapper(false, w); + e.fire_event("myevent", "abc"); + assert.spy(w).was_called(1); + assert.spy(w2).was_called(2); + assert.spy(h).was_called(2); + end); + end); + + describe("global wrappers", function () + local w + before_each(function () + w = spy.new(function (handlers, event_name, event_data) + assert.is_function(handlers); + assert.equal("myevent", event_name) + assert.equal("abc", event_data); + return handlers(event_name, event_data); + end); + end); + + it("should get called", function () + e.add_wrapper(false, w); + e.add_handler("myevent", h); + e.fire_event("myevent", "abc"); + assert.spy(w).was_called(1); + assert.spy(h).was_called(1); + end); + + it("should be removable", function () + e.add_wrapper(false, w); + e.add_handler("myevent", h); + e.fire_event("myevent", "abc"); + e.remove_wrapper(false, w); + e.fire_event("myevent", "abc"); + assert.spy(w).was_called(1); + assert.spy(h).was_called(2); + end); + end); + end); +end); diff --git a/spec/util_format_spec.lua b/spec/util_format_spec.lua new file mode 100644 index 00000000..7e6a0c6e --- /dev/null +++ b/spec/util_format_spec.lua @@ -0,0 +1,14 @@ +local format = require "util.format".format; + +describe("util.format", function() + describe("#format()", function() + it("should work", function() + assert.equal("hello", format("%s", "hello")); + assert.equal("<nil>", format("%s")); + assert.equal(" [<nil>]", format("", nil)); + assert.equal("true", format("%s", true)); + assert.equal("[true]", format("%d", true)); + assert.equal("% [true]", format("%%", true)); + end); + end); +end); diff --git a/spec/util_http_spec.lua b/spec/util_http_spec.lua new file mode 100644 index 00000000..bacfcfb5 --- /dev/null +++ b/spec/util_http_spec.lua @@ -0,0 +1,64 @@ + +local http = require "util.http"; + +describe("util.http", function() + describe("#urlencode()", function() + it("should not change normal characters", function() + assert.are.equal(http.urlencode("helloworld123"), "helloworld123"); + end); + + it("should escape spaces", function() + assert.are.equal(http.urlencode("hello world"), "hello%20world"); + end); + + it("should escape important URL characters", function() + assert.are.equal(http.urlencode("This & that = something"), "This%20%26%20that%20%3d%20something"); + end); + end); + + describe("#urldecode()", function() + it("should not change normal characters", function() + assert.are.equal("helloworld123", http.urldecode("helloworld123"), "Normal characters not escaped"); + end); + + it("should decode spaces", function() + assert.are.equal("hello world", http.urldecode("hello%20world"), "Spaces escaped"); + end); + + it("should decode important URL characters", function() + assert.are.equal("This & that = something", http.urldecode("This%20%26%20that%20%3d%20something"), "Important URL chars escaped"); + end); + end); + + describe("#formencode()", function() + it("should encode basic data", function() + assert.are.equal(http.formencode({ { name = "one", value = "1"}, { name = "two", value = "2" } }), "one=1&two=2", "Form encoded"); + end); + + it("should encode special characters with escaping", function() + assert.are.equal(http.formencode({ { name = "one two", value = "1"}, { name = "two one&", value = "2" } }), "one+two=1&two+one%26=2", "Form encoded"); + end); + end); + + describe("#formdecode()", function() + it("should decode basic data", function() + local t = http.formdecode("one=1&two=2"); + assert.are.same(t, { + { name = "one", value = "1" }; + { name = "two", value = "2" }; + one = "1"; + two = "2"; + }); + end); + + it("should decode special characters", function() + local t = http.formdecode("one+two=1&two+one%26=2"); + assert.are.same(t, { + { name = "one two", value = "1" }; + { name = "two one&", value = "2" }; + ["one two"] = "1"; + ["two one&"] = "2"; + }); + end); + end); +end); diff --git a/spec/util_ip_spec.lua b/spec/util_ip_spec.lua new file mode 100644 index 00000000..be5e4cff --- /dev/null +++ b/spec/util_ip_spec.lua @@ -0,0 +1,103 @@ + +local ip = require "util.ip"; + +local new_ip = ip.new_ip; +local match = ip.match; +local parse_cidr = ip.parse_cidr; +local commonPrefixLength = ip.commonPrefixLength; + +describe("util.ip", function() + describe("#match()", function() + it("should work", function() + local _ = new_ip; + local ip = _"10.20.30.40"; + assert.are.equal(match(ip, _"10.0.0.0", 8), true); + assert.are.equal(match(ip, _"10.0.0.0", 16), false); + assert.are.equal(match(ip, _"10.0.0.0", 24), false); + assert.are.equal(match(ip, _"10.0.0.0", 32), false); + + assert.are.equal(match(ip, _"10.20.0.0", 8), true); + assert.are.equal(match(ip, _"10.20.0.0", 16), true); + assert.are.equal(match(ip, _"10.20.0.0", 24), false); + assert.are.equal(match(ip, _"10.20.0.0", 32), false); + + assert.are.equal(match(ip, _"0.0.0.0", 32), false); + assert.are.equal(match(ip, _"0.0.0.0", 0), true); + assert.are.equal(match(ip, _"0.0.0.0"), false); + + assert.are.equal(match(ip, _"10.0.0.0", 255), false, "excessive number of bits"); + assert.are.equal(match(ip, _"10.0.0.0", -8), true, "negative number of bits"); + assert.are.equal(match(ip, _"10.0.0.0", -32), true, "negative number of bits"); + assert.are.equal(match(ip, _"10.0.0.0", 0), true, "zero bits"); + assert.are.equal(match(ip, _"10.0.0.0"), false, "no specified number of bits (differing ip)"); + assert.are.equal(match(ip, _"10.20.30.40"), true, "no specified number of bits (same ip)"); + + assert.are.equal(match(_"127.0.0.1", _"127.0.0.1"), true, "simple ip"); + + assert.are.equal(match(_"8.8.8.8", _"8.8.0.0", 16), true); + assert.are.equal(match(_"8.8.4.4", _"8.8.0.0", 16), true); + end); + end); + + describe("#parse_cidr()", function() + it("should work", function() + assert.are.equal(new_ip"0.0.0.0", new_ip"0.0.0.0") + + local function assert_cidr(cidr, ip, bits) + local parsed_ip, parsed_bits = parse_cidr(cidr); + assert.are.equal(new_ip(ip), parsed_ip, cidr.." parsed ip is "..ip); + assert.are.equal(bits, parsed_bits, cidr.." parsed bits is "..tostring(bits)); + end + assert_cidr("0.0.0.0", "0.0.0.0", nil); + assert_cidr("127.0.0.1", "127.0.0.1", nil); + assert_cidr("127.0.0.1/0", "127.0.0.1", 0); + assert_cidr("127.0.0.1/8", "127.0.0.1", 8); + assert_cidr("127.0.0.1/32", "127.0.0.1", 32); + assert_cidr("127.0.0.1/256", "127.0.0.1", 256); + assert_cidr("::/48", "::", 48); + end); + end); + + describe("#new_ip()", function() + it("should work", function() + local v4, v6 = "IPv4", "IPv6"; + local function assert_proto(s, proto) + local ip = new_ip(s); + if proto then + assert.are.equal(ip and ip.proto, proto, "protocol is correct for "..("%q"):format(s)); + else + assert.are.equal(ip, nil, "address is invalid"); + end + end + assert_proto("127.0.0.1", v4); + assert_proto("::1", v6); + assert_proto("", nil); + assert_proto("abc", nil); + assert_proto(" ", nil); + end); + end); + + describe("#commonPrefixLength()", function() + it("should work", function() + local function assert_cpl6(a, b, len, v4) + local ipa, ipb = new_ip(a), new_ip(b); + if v4 then len = len+96; end + assert.are.equal(commonPrefixLength(ipa, ipb), len, "common prefix length of "..a.." and "..b.." is "..len); + assert.are.equal(commonPrefixLength(ipb, ipa), len, "common prefix length of "..b.." and "..a.." is "..len); + end + local function assert_cpl4(a, b, len) + return assert_cpl6(a, b, len, "IPv4"); + end + assert_cpl4("0.0.0.0", "0.0.0.0", 32); + assert_cpl4("255.255.255.255", "0.0.0.0", 0); + assert_cpl4("255.255.255.255", "255.255.0.0", 16); + assert_cpl4("255.255.255.255", "255.255.255.255", 32); + assert_cpl4("255.255.255.255", "255.255.255.255", 32); + + assert_cpl6("::1", "::1", 128); + assert_cpl6("abcd::1", "abcd::1", 128); + assert_cpl6("abcd::abcd", "abcd::", 112); + assert_cpl6("abcd::abcd", "abcd::abcd:abcd", 96); + end); + end); +end); diff --git a/spec/util_iterators_spec.lua b/spec/util_iterators_spec.lua new file mode 100644 index 00000000..d00058f4 --- /dev/null +++ b/spec/util_iterators_spec.lua @@ -0,0 +1,14 @@ +local iter = require "util.iterators"; + +describe("util.iterators", function () + describe("join", function () + it("should produce a joined iterator", function () + local expect = { "a", "b", "c", 1, 2, 3 }; + local output = {}; + for x in iter.join(iter.values({"a", "b", "c"})):append(iter.values({1, 2, 3})) do + table.insert(output, x); + end + assert.same(output, expect); + end); + end); +end); diff --git a/spec/util_jid_spec.lua b/spec/util_jid_spec.lua new file mode 100644 index 00000000..c075212f --- /dev/null +++ b/spec/util_jid_spec.lua @@ -0,0 +1,146 @@ + +local jid = require "util.jid"; + +describe("util.jid", function() + describe("#join()", function() + it("should work", function() + assert.are.equal(jid.join("a", "b", "c"), "a@b/c", "builds full JID"); + assert.are.equal(jid.join("a", "b", nil), "a@b", "builds bare JID"); + assert.are.equal(jid.join(nil, "b", "c"), "b/c", "builds full host JID"); + assert.are.equal(jid.join(nil, "b", nil), "b", "builds bare host JID"); + assert.are.equal(jid.join(nil, nil, nil), nil, "invalid JID is nil"); + assert.are.equal(jid.join("a", nil, nil), nil, "invalid JID is nil"); + assert.are.equal(jid.join(nil, nil, "c"), nil, "invalid JID is nil"); + assert.are.equal(jid.join("a", nil, "c"), nil, "invalid JID is nil"); + end); + end); + describe("#split()", function() + it("should work", function() + local function test(input_jid, expected_node, expected_server, expected_resource) + local rnode, rserver, rresource = jid.split(input_jid); + assert.are.equal(expected_node, rnode, "split("..tostring(input_jid)..") failed"); + assert.are.equal(expected_server, rserver, "split("..tostring(input_jid)..") failed"); + assert.are.equal(expected_resource, rresource, "split("..tostring(input_jid)..") failed"); + end + + -- Valid JIDs + test("node@server", "node", "server", nil ); + test("node@server/resource", "node", "server", "resource" ); + test("server", nil, "server", nil ); + test("server/resource", nil, "server", "resource" ); + test("server/resource@foo", nil, "server", "resource@foo" ); + test("server/resource@foo/bar", nil, "server", "resource@foo/bar"); + + -- Always invalid JIDs + test(nil, nil, nil, nil); + test("node@/server", nil, nil, nil); + test("@server", nil, nil, nil); + test("@server/resource", nil, nil, nil); + test("@/resource", nil, nil, nil); + end); + end); + + + describe("#bare()", function() + it("should work", function() + assert.are.equal(jid.bare("user@host"), "user@host", "bare JID remains bare"); + assert.are.equal(jid.bare("host"), "host", "Host JID remains host"); + assert.are.equal(jid.bare("host/resource"), "host", "Host JID with resource becomes host"); + assert.are.equal(jid.bare("user@host/resource"), "user@host", "user@host JID with resource becomes user@host"); + assert.are.equal(jid.bare("user@/resource"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("@/resource"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("@/"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("/"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare(""), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("@"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("user@"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("user@@"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("user@@host"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("user@@host/resource"), nil, "invalid JID is nil"); + assert.are.equal(jid.bare("user@host/"), nil, "invalid JID is nil"); + end); + end); + + describe("#compare()", function() + it("should work", function() + assert.are.equal(jid.compare("host", "host"), true, "host should match"); + assert.are.equal(jid.compare("host", "other-host"), false, "host should not match"); + assert.are.equal(jid.compare("other-user@host/resource", "host"), true, "host should match"); + assert.are.equal(jid.compare("other-user@host", "user@host"), false, "user should not match"); + assert.are.equal(jid.compare("user@host", "host"), true, "host should match"); + assert.are.equal(jid.compare("user@host/resource", "host"), true, "host should match"); + assert.are.equal(jid.compare("user@host/resource", "user@host"), true, "user and host should match"); + assert.are.equal(jid.compare("user@other-host", "host"), false, "host should not match"); + assert.are.equal(jid.compare("user@other-host", "user@host"), false, "host should not match"); + end); + end); + + it("should work with nodes", function() + local function test(_jid, expected_node) + assert.are.equal(jid.node(_jid), expected_node, "Unexpected node for "..tostring(_jid)); + end + + test("example.com", nil); + test("foo.example.com", nil); + test("foo.example.com/resource", nil); + test("foo.example.com/some resource", nil); + test("foo.example.com/some@resource", nil); + + test("foo@foo.example.com/some@resource", "foo"); + test("foo@example/some@resource", "foo"); + + test("foo@example/@resource", "foo"); + test("foo@example@resource", nil); + test("foo@example", "foo"); + test("foo", nil); + + test(nil, nil); + end); + + it("should work with hosts", function() + local function test(_jid, expected_host) + assert.are.equal(jid.host(_jid), expected_host, "Unexpected host for "..tostring(_jid)); + end + + test("example.com", "example.com"); + test("foo.example.com", "foo.example.com"); + test("foo.example.com/resource", "foo.example.com"); + test("foo.example.com/some resource", "foo.example.com"); + test("foo.example.com/some@resource", "foo.example.com"); + + test("foo@foo.example.com/some@resource", "foo.example.com"); + test("foo@example/some@resource", "example"); + + test("foo@example/@resource", "example"); + test("foo@example@resource", nil); + test("foo@example", "example"); + test("foo", "foo"); + + test(nil, nil); + end); + + it("should work with resources", function() + local function test(_jid, expected_resource) + assert.are.equal(jid.resource(_jid), expected_resource, "Unexpected resource for "..tostring(_jid)); + end + + test("example.com", nil); + test("foo.example.com", nil); + test("foo.example.com/resource", "resource"); + test("foo.example.com/some resource", "some resource"); + test("foo.example.com/some@resource", "some@resource"); + + test("foo@foo.example.com/some@resource", "some@resource"); + test("foo@example/some@resource", "some@resource"); + + test("foo@example/@resource", "@resource"); + test("foo@example@resource", nil); + test("foo@example", nil); + test("foo", nil); + test("/foo", nil); + test("@x/foo", nil); + test("@/foo", nil); + + test(nil, nil); + end); +end); diff --git a/spec/util_json_spec.lua b/spec/util_json_spec.lua new file mode 100644 index 00000000..dc66c7ba --- /dev/null +++ b/spec/util_json_spec.lua @@ -0,0 +1,70 @@ + +local json = require "util.json"; + +describe("util.json", function() + describe("#encode()", function() + it("should work", function() + local function test(f, j, e) + if e then + assert.are.equal(f(j), e); + end + assert.are.equal(f(j), f(json.decode(f(j)))); + end + test(json.encode, json.null, "null") + test(json.encode, {}, "{}") + test(json.encode, {a=1}); + test(json.encode, {a={1,2,3}}); + test(json.encode, {1}, "[1]"); + end); + end); + + describe("#decode()", function() + it("should work", function() + local empty_array = json.decode("[]"); + assert.are.equal(type(empty_array), "table"); + assert.are.equal(#empty_array, 0); + assert.are.equal(next(empty_array), nil); + end); + end); + + describe("testcases", function() + + local valid_data = {}; + local invalid_data = {}; + + local skip = "fail1.json fail9.json fail18.json fail15.json fail13.json fail25.json fail26.json fail27.json fail28.json fail17.json pass1.json"; + + setup(function() + local lfs = require "lfs"; + local path = "spec/json"; + for name in lfs.dir(path) do + if name:match("%.json$") then + local f = assert(io.open(path.."/"..name)); + local content = assert(f:read("*a")); + assert(f:close()); + if skip:find(name) then + -- Skip + elseif name:match("^pass") then + valid_data[name] = content; + elseif name:match("^fail") then + invalid_data[name] = content; + end + end + end + end) + + it("should pass valid testcases", function() + for name, content in pairs(valid_data) do + local parsed, err = json.decode(content); + assert(parsed, name..": "..tostring(err)); + end + end); + + it("should fail invalid testcases", function() + for name, content in pairs(invalid_data) do + local parsed, err = json.decode(content); + assert(not parsed, name..": "..tostring(err)); + end + end); + end) +end); diff --git a/spec/util_multitable_spec.lua b/spec/util_multitable_spec.lua new file mode 100644 index 00000000..40759f7a --- /dev/null +++ b/spec/util_multitable_spec.lua @@ -0,0 +1,60 @@ + +local multitable = require "util.multitable"; + +describe("util.multitable", function() + describe("#new()", function() + it("should create a multitable", function() + local mt = multitable.new(); + assert.is_table(mt, "Multitable is a table"); + assert.is_function(mt.add, "Multitable has method add"); + assert.is_function(mt.get, "Multitable has method get"); + assert.is_function(mt.remove, "Multitable has method remove"); + end); + end); + + describe("#get()", function() + it("should allow getting correctly", function() + local function has_items(list, ...) + local should_have = {}; + if select('#', ...) > 0 then + assert.is_table(list, "has_items: list is table", 3); + else + assert.is.falsy(list and #list > 0, "No items, and no list"); + return true, "has-all"; + end + for n=1,select('#', ...) do should_have[select(n, ...)] = true; end + for _, item in ipairs(list) do + if not should_have[item] then return false, "too-many"; end + should_have[item] = nil; + end + if next(should_have) then + return false, "not-enough"; + end + return true, "has-all"; + end + local function assert_has_all(message, list, ...) + return assert.are.equal(select(2, has_items(list, ...)), "has-all", message or "List has all expected items, and no more", 2); + end + + local mt = multitable.new(); + + local trigger1, trigger2, trigger3 = {}, {}, {}; + local item1, item2, item3 = {}, {}, {}; + + assert_has_all("Has no items with trigger1", mt:get(trigger1)); + + + mt:add(1, 2, 3, item1); + + assert_has_all("Has item1 for 1, 2, 3", mt:get(1, 2, 3), item1); + end); + end); + + -- Doesn't support nil + --[[ mt:add(nil, item1); + mt:add(nil, item2); + mt:add(nil, item3); + + assert_has_all("Has all items with (nil)", mt:get(nil), item1, item2, item3); + ]] +end); diff --git a/spec/util_pubsub_spec.lua b/spec/util_pubsub_spec.lua new file mode 100644 index 00000000..1c9a9e02 --- /dev/null +++ b/spec/util_pubsub_spec.lua @@ -0,0 +1,67 @@ +local pubsub; +setup(function () + pubsub = require "util.pubsub"; +end); + +describe("util.pubsub", function () + describe("simple node creation and deletion", function () + -- Roughly a port of scansion/scripts/pubsub_createdelete.scs + local service = pubsub.new(); + + describe("#create", function () + it("creates a new node", function () + assert.truthy(service:create("princely_musings", true)); + end); + + it("fails to create the same node again", function () + assert.falsy(service:create("princely_musings", true)); + end); + end); + + describe("#delete", function () + it("deletes the node", function () + assert.truthy(service:delete("princely_musings", true)); + end); + + it("can't delete an already deleted node", function () + assert.falsy(service:delete("princely_musings", true)); + end); + end); + end); + + describe("simple publishing", function () + local broadcaster = spy.new(function () end); + local service = pubsub.new({ + broadcaster = broadcaster; + capabilities = { + none = { + subscribe = true; + be_subscribed = true; + }; + } + }); + + it("creates a node", function () + assert.truthy(service:create("node", true)); + end); + + it("lets someone subscribe", function () + assert.truthy(service:add_subscription("node", true, "someone")); + end); + + it("publishes an item", function () + assert.truthy(service:publish("node", true, "1", "item 1")); + end); + + it("called the broadcaster", function () + assert.spy(broadcaster).was_called(); + end); + + it("should return one item", function () + local ok, ret = service:get_items("node", true); + assert.truthy(ok); + assert.same({ "1", ["1"] = "item 1" }, ret); + end); + + end); +end); diff --git a/spec/util_queue_spec.lua b/spec/util_queue_spec.lua new file mode 100644 index 00000000..7cd3d695 --- /dev/null +++ b/spec/util_queue_spec.lua @@ -0,0 +1,103 @@ + +local queue = require "util.queue"; + +describe("util.queue", function() + describe("#new()", function() + it("should work", function() + + do + local q = queue.new(10); + + assert.are.equal(q.size, 10); + assert.are.equal(q:count(), 0); + + assert.is_true(q:push("one")); + assert.is_true(q:push("two")); + assert.is_true(q:push("three")); + + for i = 4, 10 do + assert.is_true(q:push("hello")); + assert.are.equal(q:count(), i, "count is not "..i.."("..q:count()..")"); + end + assert.are.equal(q:push("hello"), nil, "queue overfull!"); + assert.are.equal(q:push("hello"), nil, "queue overfull!"); + assert.are.equal(q:pop(), "one", "queue item incorrect"); + assert.are.equal(q:pop(), "two", "queue item incorrect"); + assert.is_true(q:push("hello")); + assert.is_true(q:push("hello")); + assert.are.equal(q:pop(), "three", "queue item incorrect"); + assert.is_true(q:push("hello")); + assert.are.equal(q:push("hello"), nil, "queue overfull!"); + assert.are.equal(q:push("hello"), nil, "queue overfull!"); + + assert.are.equal(q:count(), 10, "queue count incorrect"); + + for _ = 1, 10 do + assert.are.equal(q:pop(), "hello", "queue item incorrect"); + end + + assert.are.equal(q:count(), 0, "queue count incorrect"); + assert.are.equal(q:pop(), nil, "empty queue pops non-nil result"); + assert.are.equal(q:count(), 0, "popping empty queue affects count"); + + assert.are.equal(q:peek(), nil, "empty queue peeks non-nil result"); + assert.are.equal(q:count(), 0, "peeking empty queue affects count"); + + assert.is_true(q:push(1)); + for i = 1, 1001 do + assert.are.equal(q:pop(), i); + assert.are.equal(q:count(), 0); + assert.is_true(q:push(i+1)); + assert.are.equal(q:count(), 1); + end + assert.are.equal(q:pop(), 1002); + assert.is_true(q:push(1)); + for i = 1, 1000 do + assert.are.equal(q:pop(), i); + assert.is_true(q:push(i+1)); + end + assert.are.equal(q:pop(), 1001); + assert.are.equal(q:count(), 0); + end + + do + -- Test queues that purge old items when pushing to a full queue + local q = queue.new(10, true); + + for i = 1, 10 do + q:push(i); + end + + assert.are.equal(q:count(), 10); + + assert.is_true(q:push(11)); + assert.are.equal(q:count(), 10); + assert.are.equal(q:pop(), 2); -- First item should have been purged + assert.are.equal(q:peek(), 3); + + for i = 12, 32 do + assert.is_true(q:push(i)); + end + + assert.are.equal(q:count(), 10); + assert.are.equal(q:pop(), 23); + end + + do + -- Test iterator + local q = queue.new(10, true); + + for i = 1, 10 do + q:push(i); + end + + local i = 0; + for item in q:items() do + i = i + 1; + assert.are.equal(item, i, "unexpected item returned by iterator") + end + end + + end); + end); +end); diff --git a/spec/util_random_spec.lua b/spec/util_random_spec.lua new file mode 100644 index 00000000..c080a2c9 --- /dev/null +++ b/spec/util_random_spec.lua @@ -0,0 +1,19 @@ + +local random = require "util.random"; + +describe("util.random", function() + describe("#bytes()", function() + it("should return a string", function() + assert.is_string(random.bytes(16)); + end); + + it("should return the requested number of bytes", function() + -- Makes no attempt at testing how random the bytes are, + -- just that it returns the number of bytes requested + + for i = 1, 20 do + assert.are.equal(2^i, #random.bytes(2^i)); + end + end); + end); +end); diff --git a/spec/util_rfc6724_spec.lua b/spec/util_rfc6724_spec.lua new file mode 100644 index 00000000..30e935b6 --- /dev/null +++ b/spec/util_rfc6724_spec.lua @@ -0,0 +1,97 @@ + +local rfc6724 = require "util.rfc6724"; +local new_ip = require"util.ip".new_ip; + +describe("util.rfc6724", function() + describe("#source()", function() + it("should work", function() + assert.are.equal(rfc6724.source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::1", "IPv6")}).addr, + "2001:db8:3::1", + "prefer appropriate scope"); + assert.are.equal(rfc6724.source(new_ip("ff05::1", "IPv6"), + {new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::1", "IPv6")}).addr, + "2001:db8:3::1", + "prefer appropriate scope"); + assert.are.equal(rfc6724.source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:1::1", "IPv6"), new_ip("2001:db8:2::1", "IPv6")}).addr, + "2001:db8:1::1", + "prefer same address"); -- "2001:db8:1::1" should be marked "deprecated" here, we don't handle that right now + assert.are.equal(rfc6724.source(new_ip("fe80::1", "IPv6"), + {new_ip("fe80::2", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}).addr, + "fe80::2", + "prefer appropriate scope"); -- "fe80::2" should be marked "deprecated" here, we don't handle that right now + assert.are.equal(rfc6724.source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::2", "IPv6")}).addr, + "2001:db8:1::2", + "longest matching prefix"); + --[[ "2001:db8:1::2" should be a care-of address and "2001:db8:3::2" a home address, we can't handle this and would fail + assert.are.equal(rfc6724.source(new_ip("2001:db8:1::1", "IPv6"), + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::2", "IPv6")}).addr, + "2001:db8:3::2", + "prefer home address"); + ]] + assert.are.equal(rfc6724.source(new_ip("2002:c633:6401::1", "IPv6"), + {new_ip("2002:c633:6401::d5e3:7953:13eb:22e8", "IPv6"), new_ip("2001:db8:1::2", "IPv6")}).addr, + "2002:c633:6401::d5e3:7953:13eb:22e8", + "prefer matching label"); -- "2002:c633:6401::d5e3:7953:13eb:22e8" should be marked "temporary" here, we don't handle that right now + assert.are.equal(rfc6724.source(new_ip("2001:db8:1::d5e3:0:0:1", "IPv6"), + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:1::d5e3:7953:13eb:22e8", "IPv6")}).addr, + "2001:db8:1::d5e3:7953:13eb:22e8", + "prefer temporary address") -- "2001:db8:1::2" should be marked "public" and "2001:db8:1::d5e3:7953:13eb:22e8" should be marked "temporary" here, we don't handle that right now + end); + end); + describe("#destination()", function() + it("should work", function() + local order; + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("198.51.100.121", "IPv4")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::1", "IPv6"), new_ip("169.254.13.78", "IPv4")}) + assert.are.equal(order[1].addr, "2001:db8:1::1", "prefer matching scope"); + assert.are.equal(order[2].addr, "198.51.100.121", "prefer matching scope"); + + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("198.51.100.121", "IPv4")}, + {new_ip("fe80::1", "IPv6"), new_ip("198.51.100.117", "IPv4")}) + assert.are.equal(order[1].addr, "198.51.100.121", "prefer matching scope"); + assert.are.equal(order[2].addr, "2001:db8:1::1", "prefer matching scope"); + + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("10.1.2.3", "IPv4")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::1", "IPv6"), new_ip("10.1.2.4", "IPv4")}) + assert.are.equal(order[1].addr, "2001:db8:1::1", "prefer higher precedence"); + assert.are.equal(order[2].addr, "10.1.2.3", "prefer higher precedence"); + + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert.are.equal(order[1].addr, "fe80::1", "prefer smaller scope"); + assert.are.equal(order[2].addr, "2001:db8:1::1", "prefer smaller scope"); + + --[[ "2001:db8:1::2" and "fe80::2" should be marked "care-of address", while "2001:db8:3::1" should be marked "home address", we can't currently handle this and would fail the test + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert.are.equal(order[1].addr, "2001:db8:1::1", "prefer home address"); + assert.are.equal(order[2].addr, "fe80::1", "prefer home address"); + ]] + + --[[ "fe80::2" should be marked "deprecated", we can't currently handle this and would fail the test + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert.are.equal(order[1].addr, "2001:db8:1::1", "avoid deprecated addresses"); + assert.are.equal(order[2].addr, "fe80::1", "avoid deprecated addresses"); + ]] + + order = rfc6724.destination({new_ip("2001:db8:1::1", "IPv6"), new_ip("2001:db8:3ffe::1", "IPv6")}, + {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3f44::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert.are.equal(order[1].addr, "2001:db8:1::1", "longest matching prefix"); + assert.are.equal(order[2].addr, "2001:db8:3ffe::1", "longest matching prefix"); + + order = rfc6724.destination({new_ip("2002:c633:6401::1", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}, + {new_ip("2002:c633:6401::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert.are.equal(order[1].addr, "2002:c633:6401::1", "prefer matching label"); + assert.are.equal(order[2].addr, "2001:db8:1::1", "prefer matching label"); + + order = rfc6724.destination({new_ip("2002:c633:6401::1", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}, + {new_ip("2002:c633:6401::2", "IPv6"), new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) + assert.are.equal(order[1].addr, "2001:db8:1::1", "prefer higher precedence"); + assert.are.equal(order[2].addr, "2002:c633:6401::1", "prefer higher precedence"); + end); + end); +end); diff --git a/spec/util_stanza_spec.lua b/spec/util_stanza_spec.lua new file mode 100644 index 00000000..36bf213f --- /dev/null +++ b/spec/util_stanza_spec.lua @@ -0,0 +1,238 @@ + +local st = require "util.stanza"; + +describe("util.stanza", function() + describe("#preserialize()", function() + it("should work", function() + local stanza = st.stanza("message", { a = "a" }); + local stanza2 = st.preserialize(stanza); + assert.is_string(stanza2 and stanza.name, "preserialize returns a stanza"); + assert.is_nil(stanza2.tags, "Preserialized stanza has no tag list"); + assert.is_nil(stanza2.last_add, "Preserialized stanza has no last_add marker"); + assert.is_nil(getmetatable(stanza2), "Preserialized stanza has no metatable"); + end); + end); + + describe("#preserialize()", function() + it("should work", function() + local stanza = st.stanza("message", { a = "a" }); + local stanza2 = st.deserialize(st.preserialize(stanza)); + assert.is_string(stanza2 and stanza.name, "deserialize returns a stanza"); + assert.is_table(stanza2.attr, "Deserialized stanza has attributes"); + assert.are.equal(stanza2.attr.a, "a", "Deserialized stanza retains attributes"); + assert.is_table(getmetatable(stanza2), "Deserialized stanza has metatable"); + end); + end); + + describe("#stanza()", function() + it("should work", function() + local s = st.stanza("foo", { xmlns = "myxmlns", a = "attr-a" }); + assert.are.equal(s.name, "foo"); + assert.are.equal(s.attr.xmlns, "myxmlns"); + assert.are.equal(s.attr.a, "attr-a"); + + local s1 = st.stanza("s1"); + assert.are.equal(s1.name, "s1"); + assert.are.equal(s1.attr.xmlns, nil); + assert.are.equal(#s1, 0); + assert.are.equal(#s1.tags, 0); + + s1:tag("child1"); + assert.are.equal(#s1.tags, 1); + assert.are.equal(s1.tags[1].name, "child1"); + + s1:tag("grandchild1"):up(); + assert.are.equal(#s1.tags, 1); + assert.are.equal(s1.tags[1].name, "child1"); + assert.are.equal(#s1.tags[1], 1); + assert.are.equal(s1.tags[1][1].name, "grandchild1"); + + s1:up():tag("child2"); + assert.are.equal(#s1.tags, 2, tostring(s1)); + assert.are.equal(s1.tags[1].name, "child1"); + assert.are.equal(s1.tags[2].name, "child2"); + assert.are.equal(#s1.tags[1], 1); + assert.are.equal(s1.tags[1][1].name, "grandchild1"); + + s1:up():text("Hello world"); + assert.are.equal(#s1.tags, 2); + assert.are.equal(#s1, 3); + assert.are.equal(s1.tags[1].name, "child1"); + assert.are.equal(s1.tags[2].name, "child2"); + assert.are.equal(#s1.tags[1], 1); + assert.are.equal(s1.tags[1][1].name, "grandchild1"); + end); + it("should work with unicode values", function () + local s = st.stanza("Объект", { xmlns = "myxmlns", ["Объект"] = "&" }); + assert.are.equal(s.name, "Объект"); + assert.are.equal(s.attr.xmlns, "myxmlns"); + assert.are.equal(s.attr["Объект"], "&"); + end); + it("should allow :text() with nil and empty strings", function () + local s_control = st.stanza("foo"); + assert.same(st.stanza("foo"):text(), s_control); + assert.same(st.stanza("foo"):text(nil), s_control); + assert.same(st.stanza("foo"):text(""), s_control); + end); + end); + + describe("#message()", function() + it("should work", function() + local m = st.message(); + assert.are.equal(m.name, "message"); + end); + end); + + describe("#iq()", function() + it("should work", function() + local i = st.iq(); + assert.are.equal(i.name, "iq"); + end); + end); + + describe("#iq()", function() + it("should work", function() + local p = st.presence(); + assert.are.equal(p.name, "presence"); + end); + end); + + describe("#reply()", function() + it("should work for <s>", function() + -- Test stanza + local s = st.stanza("s", { to = "touser", from = "fromuser", id = "123" }) + :tag("child1"); + -- Make reply stanza + local r = st.reply(s); + assert.are.equal(r.name, s.name); + assert.are.equal(r.id, s.id); + assert.are.equal(r.attr.to, s.attr.from); + assert.are.equal(r.attr.from, s.attr.to); + assert.are.equal(#r.tags, 0, "A reply should not include children of the original stanza"); + end); + + it("should work for <iq get>", function() + -- Test stanza + local s = st.stanza("iq", { to = "touser", from = "fromuser", id = "123", type = "get" }) + :tag("child1"); + -- Make reply stanza + local r = st.reply(s); + assert.are.equal(r.name, s.name); + assert.are.equal(r.id, s.id); + assert.are.equal(r.attr.to, s.attr.from); + assert.are.equal(r.attr.from, s.attr.to); + assert.are.equal(r.attr.type, "result"); + assert.are.equal(#r.tags, 0, "A reply should not include children of the original stanza"); + end); + + it("should work for <iq set>", function() + -- Test stanza + local s = st.stanza("iq", { to = "touser", from = "fromuser", id = "123", type = "set" }) + :tag("child1"); + -- Make reply stanza + local r = st.reply(s); + assert.are.equal(r.name, s.name); + assert.are.equal(r.id, s.id); + assert.are.equal(r.attr.to, s.attr.from); + assert.are.equal(r.attr.from, s.attr.to); + assert.are.equal(r.attr.type, "result"); + assert.are.equal(#r.tags, 0, "A reply should not include children of the original stanza"); + end); + end); + + describe("#error_reply()", function() + it("should work for <s>", function() + -- Test stanza + local s = st.stanza("s", { to = "touser", from = "fromuser", id = "123" }) + :tag("child1"); + -- Make reply stanza + local r = st.error_reply(s, "cancel", "service-unavailable"); + assert.are.equal(r.name, s.name); + assert.are.equal(r.id, s.id); + assert.are.equal(r.attr.to, s.attr.from); + assert.are.equal(r.attr.from, s.attr.to); + assert.are.equal(#r.tags, 1); + assert.are.equal(r.tags[1].tags[1].name, "service-unavailable"); + end); + + it("should work for <iq get>", function() + -- Test stanza + local s = st.stanza("iq", { to = "touser", from = "fromuser", id = "123", type = "get" }) + :tag("child1"); + -- Make reply stanza + local r = st.error_reply(s, "cancel", "service-unavailable"); + assert.are.equal(r.name, s.name); + assert.are.equal(r.id, s.id); + assert.are.equal(r.attr.to, s.attr.from); + assert.are.equal(r.attr.from, s.attr.to); + assert.are.equal(r.attr.type, "error"); + assert.are.equal(#r.tags, 1); + assert.are.equal(r.tags[1].tags[1].name, "service-unavailable"); + end); + end); + + describe("should reject #invalid", function () + local invalid_names = { + ["empty string"] = "", ["characters"] = "<>"; + } + local invalid_data = { + ["number"] = 1234, ["table"] = {}; + ["utf8"] = string.char(0xF4, 0x90, 0x80, 0x80); + ["nil"] = "nil"; ["boolean"] = true; + }; + + for value_type, value in pairs(invalid_names) do + it(value_type.." in tag names", function () + assert.error_matches(function () + st.stanza(value); + end, value_type); + end); + it(value_type.." in attribute names", function () + assert.error_matches(function () + st.stanza("valid", { [value] = "valid" }); + end, value_type); + end); + end + for value_type, value in pairs(invalid_data) do + if value == "nil" then value = nil; end + it(value_type.." in tag names", function () + assert.error_matches(function () + st.stanza(value); + end, value_type); + end); + it(value_type.." in attribute names", function () + assert.error_matches(function () + st.stanza("valid", { [value] = "valid" }); + end, value_type); + end); + if value ~= nil then + it(value_type.." in attribute values", function () + assert.error_matches(function () + st.stanza("valid", { valid = value }); + end, value_type); + end); + it(value_type.." in text node", function () + assert.error_matches(function () + st.stanza("valid"):text(value); + end, value_type); + end); + end + end + end); + + describe("#is_stanza", function () + -- is_stanza(any) -> boolean + it("identifies stanzas as stanzas", function () + assert.truthy(st.is_stanza(st.stanza("x"))); + end); + it("identifies strings as not stanzas", function () + assert.falsy(st.is_stanza("")); + end); + it("identifies numbers as not stanzas", function () + assert.falsy(st.is_stanza(1)); + end); + it("identifies tables as not stanzas", function () + assert.falsy(st.is_stanza({})); + end); + end); +end); diff --git a/spec/util_throttle_spec.lua b/spec/util_throttle_spec.lua new file mode 100644 index 00000000..75daf1b9 --- /dev/null +++ b/spec/util_throttle_spec.lua @@ -0,0 +1,150 @@ + + +-- Mock util.time +local now = 0; -- wibbly-wobbly... timey-wimey... stuff +local function later(n) + now = now + n; -- time passes at a different rate +end +package.loaded["util.time"] = { + now = function() return now; end +} + + +local throttle = require "util.throttle"; + +describe("util.throttle", function() + describe("#create()", function() + it("should be created with correct values", function() + now = 5; + local a = throttle.create(3, 10); + assert.same(a, { balance = 3, max = 3, rate = 0.3, t = 5 }); + + local a = throttle.create(3, 5); + assert.same(a, { balance = 3, max = 3, rate = 0.6, t = 5 }); + + local a = throttle.create(1, 1); + assert.same(a, { balance = 1, max = 1, rate = 1, t = 5 }); + + local a = throttle.create(10, 10); + assert.same(a, { balance = 10, max = 10, rate = 1, t = 5 }); + + local a = throttle.create(10, 1); + assert.same(a, { balance = 10, max = 10, rate = 10, t = 5 }); + end); + end); + + describe("#update()", function() + it("does nothing when no time has passed, even if balance is not full", function() + now = 5; + local a = throttle.create(10, 10); + for i=1,5 do + a:update(); + assert.same(a, { balance = 10, max = 10, rate = 1, t = 5 }); + end + a.balance = 0; + for i=1,5 do + a:update(); + assert.same(a, { balance = 0, max = 10, rate = 1, t = 5 }); + end + end); + it("updates only time when time passes but balance is full", function() + now = 5; + local a = throttle.create(10, 10); + for i=1,5 do + later(5); + a:update(); + assert.same(a, { balance = 10, max = 10, rate = 1, t = 5 + i*5 }); + end + end); + it("updates balance when balance has room to grow as time passes", function() + now = 5; + local a = throttle.create(10, 10); + a.balance = 0; + assert.same(a, { balance = 0, max = 10, rate = 1, t = 5 }); + + later(1); + a:update(); + assert.same(a, { balance = 1, max = 10, rate = 1, t = 6 }); + + later(3); + a:update(); + assert.same(a, { balance = 4, max = 10, rate = 1, t = 9 }); + + later(10); + a:update(); + assert.same(a, { balance = 10, max = 10, rate = 1, t = 19 }); + end); + it("handles 10 x 0.1s updates the same as 1 x 1s update ", function() + now = 5; + local a = throttle.create(1, 1); + + a.balance = 0; + later(1); + a:update(); + assert.same(a, { balance = 1, max = 1, rate = 1, t = now }); + + a.balance = 0; + for i=1,10 do + later(0.1); + a:update(); + end + assert(math.abs(a.balance - 1) < 0.0001); -- incremental updates cause rouding errors + end); + end); + + -- describe("po") + + describe("#poll()", function() + it("should only allow successful polls until cost is hit", function() + now = 5; + + local a = throttle.create(3, 10); + assert.same(a, { balance = 3, max = 3, rate = 0.3, t = 5 }); + + assert.is_true(a:poll(1)); -- 3 -> 2 + assert.same(a, { balance = 2, max = 3, rate = 0.3, t = 5 }); + + assert.is_true(a:poll(2)); -- 2 -> 1 + assert.same(a, { balance = 0, max = 3, rate = 0.3, t = 5 }); + + assert.is_false(a:poll(1)); -- MEEP, out of credits! + assert.is_false(a:poll(1)); -- MEEP, out of credits! + assert.same(a, { balance = 0, max = 3, rate = 0.3, t = 5 }); + end); + + it("should not allow polls more than the cost", function() + now = 0; + + local a = throttle.create(10, 10); + assert.same(a, { balance = 10, max = 10, rate = 1, t = 0 }); + + assert.is_false(a:poll(11)); + assert.same(a, { balance = 10, max = 10, rate = 1, t = 0 }); + + assert.is_true(a:poll(6)); + assert.same(a, { balance = 4, max = 10, rate = 1, t = 0 }); + + assert.is_false(a:poll(5)); + assert.same(a, { balance = 4, max = 10, rate = 1, t = 0 }); + + -- fractional + assert.is_true(a:poll(3.5)); + assert.same(a, { balance = 0.5, max = 10, rate = 1, t = 0 }); + + assert.is_true(a:poll(0.25)); + assert.same(a, { balance = 0.25, max = 10, rate = 1, t = 0 }); + + assert.is_false(a:poll(0.3)); + assert.same(a, { balance = 0.25, max = 10, rate = 1, t = 0 }); + + assert.is_true(a:poll(0.25)); + assert.same(a, { balance = 0, max = 10, rate = 1, t = 0 }); + + assert.is_false(a:poll(0.1)); + assert.same(a, { balance = 0, max = 10, rate = 1, t = 0 }); + + assert.is_true(a:poll(0)); + assert.same(a, { balance = 0, max = 10, rate = 1, t = 0 }); + end); + end); +end); diff --git a/spec/util_uuid_spec.lua b/spec/util_uuid_spec.lua new file mode 100644 index 00000000..95ae0a20 --- /dev/null +++ b/spec/util_uuid_spec.lua @@ -0,0 +1,25 @@ +-- This tests the format, not the randomness + +local uuid = require "util.uuid"; + +describe("util.uuid", function() + describe("#generate()", function() + it("should work follow the UUID pattern", function() + -- https://tools.ietf.org/html/rfc4122#section-4.4 + + local pattern = "^" .. table.concat({ + string.rep("%x", 8), + string.rep("%x", 4), + "4" .. -- version + string.rep("%x", 3), + "[89ab]" .. -- reserved bits of 1 and 0 + string.rep("%x", 3), + string.rep("%x", 12), + }, "%-") .. "$"; + + for _ = 1, 100 do + assert.is_string(uuid.generate():match(pattern)); + end + end); + end); +end); diff --git a/spec/util_xml_spec.lua b/spec/util_xml_spec.lua new file mode 100644 index 00000000..11820894 --- /dev/null +++ b/spec/util_xml_spec.lua @@ -0,0 +1,20 @@ + +local xml = require "util.xml"; + +describe("util.xml", function() + describe("#parse()", function() + it("should work", function() + local x = +[[<x xmlns:a="b"> + <y xmlns:a="c"> <!-- this overwrites 'a' --> + <a:z/> + </y> + <a:z/> <!-- prefix 'a' is nil here, but should be 'b' --> +</x> +]] + local stanza = xml.parse(x); + assert.are.equal(stanza.tags[2].attr.xmlns, "b"); + assert.are.equal(stanza.tags[2].namespaces["a"], "b"); + end); + end); +end); diff --git a/spec/util_xmppstream_spec.lua b/spec/util_xmppstream_spec.lua new file mode 100644 index 00000000..f03a806e --- /dev/null +++ b/spec/util_xmppstream_spec.lua @@ -0,0 +1,90 @@ + +local xmppstream = require "util.xmppstream"; + +describe("util.xmppstream", function() + describe("#new()", function() + it("should work", function() + local function test(xml, expect_success, ex) + local stanzas = {}; + local session = { notopen = true }; + local callbacks = { + stream_ns = "streamns"; + stream_tag = "stream"; + default_ns = "stanzans"; + streamopened = function (_session) + assert.are.equal(session, _session); + assert.are.equal(session.notopen, true); + _session.notopen = nil; + return true; + end; + handlestanza = function (_session, stanza) + assert.are.equal(session, _session); + assert.are.equal(_session.notopen, nil); + table.insert(stanzas, stanza); + end; + streamclosed = function (_session) + assert.are.equal(session, _session); + assert.are.equal(_session.notopen, nil); + _session.notopen = nil; + end; + } + if type(ex) == "table" then + for k, v in pairs(ex) do + if k ~= "_size_limit" then + callbacks[k] = v; + end + end + end + local stream = xmppstream.new(session, callbacks, size_limit); + local ok, err = pcall(function () + assert(stream:feed(xml)); + end); + + if ok and type(expect_success) == "function" then + expect_success(stanzas); + end + assert.are.equal(not not ok, not not expect_success, "Expected "..(expect_success and ("success ("..tostring(err)..")") or "failure")); + end + + local function test_stanza(stanza, expect_success, ex) + return test([[<stream:stream xmlns:stream="streamns" xmlns="stanzans">]]..stanza, expect_success, ex); + end + + test([[<stream:stream xmlns:stream="streamns"/>]], true); + test([[<stream xmlns="streamns"/>]], true); + + test([[<stream1 xmlns="streamns"/>]], false); + test([[<stream xmlns="streamns1"/>]], false); + test("<>", false); + + test_stanza("<message/>", function (stanzas) + assert.are.equal(#stanzas, 1); + assert.are.equal(stanzas[1].name, "message"); + end); + test_stanza("< message>>>>/>\n", false); + + test_stanza([[<x xmlns:a="b"> + <y xmlns:a="c"> + <a:z/> + </y> + <a:z/> + </x>]], function (stanzas) + assert.are.equal(#stanzas, 1); + local s = stanzas[1]; + assert.are.equal(s.name, "x"); + assert.are.equal(#s.tags, 2); + + assert.are.equal(s.tags[1].name, "y"); + assert.are.equal(s.tags[1].attr.xmlns, nil); + + assert.are.equal(s.tags[1].tags[1].name, "z"); + assert.are.equal(s.tags[1].tags[1].attr.xmlns, "c"); + + assert.are.equal(s.tags[2].name, "z"); + assert.are.equal(s.tags[2].attr.xmlns, "b"); + + assert.are.equal(s.namespaces, nil); + end); + end); + end); +end); diff --git a/tests/modulemanager_option_conversion.lua b/tests/modulemanager_option_conversion.lua deleted file mode 100644 index 100dbe83..00000000 --- a/tests/modulemanager_option_conversion.lua +++ /dev/null @@ -1,55 +0,0 @@ -package.path = "../?.lua;"..package.path; - -local api = require "core.modulemanager".api; - -local module = setmetatable({}, {__index = api}); -local opt = nil; -function module:log() end -function module:get_option(name) - if name == "opt" then - return opt; - else - return nil; - end -end - -function test_value(value, returns) - opt = value; - assert(module:get_option_number("opt") == returns.number, "number doesn't match"); - assert(module:get_option_string("opt") == returns.string, "string doesn't match"); - assert(module:get_option_boolean("opt") == returns.boolean, "boolean doesn't match"); - - if type(returns.array) == "table" then - local target_array, returned_array = returns.array, module:get_option_array("opt"); - assert(#target_array == #returned_array, "array length doesn't match"); - for i=1,#target_array do - assert(target_array[i] == returned_array[i], "array item doesn't match"); - end - else - assert(module:get_option_array("opt") == returns.array, "array is returned (not nil)"); - end - - if type(returns.set) == "table" then - local target_items, returned_items = set.new(returns.set), module:get_option_set("opt"); - assert(target_items == returned_items, "set doesn't match"); - else - assert(module:get_option_set("opt") == returns.set, "set is returned (not nil)"); - end -end - -test_value(nil, {}); - -test_value(true, { boolean = true, string = "true", array = {true}, set = {true} }); -test_value(false, { boolean = false, string = "false", array = {false}, set = {false} }); -test_value("true", { boolean = true, string = "true", array = {"true"}, set = {"true"} }); -test_value("false", { boolean = false, string = "false", array = {"false"}, set = {"false"} }); -test_value(1, { boolean = true, string = "1", array = {1}, set = {1}, number = 1 }); -test_value(0, { boolean = false, string = "0", array = {0}, set = {0}, number = 0 }); - -test_value("hello world", { string = "hello world", array = {"hello world"}, set = {"hello world"} }); -test_value(1234, { string = "1234", number = 1234, array = {1234}, set = {1234} }); - -test_value({1, 2, 3}, { boolean = true, string = "1", number = 1, array = {1, 2, 3}, set = {1, 2, 3} }); -test_value({1, 2, 3, 3, 4}, {boolean = true, string = "1", number = 1, array = {1, 2, 3, 3, 4}, set = {1, 2, 3, 4} }); -test_value({0, 1, 2, 3}, { boolean = false, string = "0", number = 0, array = {0, 1, 2, 3}, set = {0, 1, 2, 3} }); - diff --git a/tests/reports/empty b/tests/reports/empty deleted file mode 100644 index 0e3c9a08..00000000 --- a/tests/reports/empty +++ /dev/null @@ -1 +0,0 @@ -This file was intentionally left blank. diff --git a/tests/run_tests.bat b/tests/run_tests.bat deleted file mode 100644 index 648081f5..00000000 --- a/tests/run_tests.bat +++ /dev/null @@ -1,10 +0,0 @@ -@echo off
-
-set oldpath=%path%
-set path=%path%;..;..\lualibs
-
-del reports\*.report
-lua test.lua %*
-
-set path=%oldpath%
-set oldpath=
\ No newline at end of file diff --git a/tests/run_tests.sh b/tests/run_tests.sh deleted file mode 100755 index 7f1ee700..00000000 --- a/tests/run_tests.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/sh -rm reports/*.report -exec lua test.lua "$@" diff --git a/tests/test.lua b/tests/test.lua deleted file mode 100644 index bc33bb76..00000000 --- a/tests/test.lua +++ /dev/null @@ -1,255 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local tests_passed = true; - -function run_all_tests() - package.loaded["net.connlisteners"] = { get = function () return {} end }; - dotest "util.jid" - dotest "util.multitable" - dotest "util.rfc6724" - dotest "util.http" - dotest "core.stanza_router" - dotest "core.s2smanager" - dotest "core.configmanager" - dotest "util.ip" - dotest "util.json" - dotest "util.stanza" - dotest "util.sasl.scram" - dotest "util.cache" - dotest "util.throttle" - dotest "util.uuid" - dotest "util.random" - dotest "util.xml" - dotest "util.xmppstream" - dotest "util.queue" - dotest "net.http.parser" - - dosingletest("test_sasl.lua", "latin1toutf8"); - dosingletest("test_utf8.lua", "valid"); -end - -local verbosity = tonumber(arg[1]) or 2; - -if os.getenv("WINDIR") then - package.path = package.path..";..\\?.lua"; - package.cpath = package.cpath..";..\\?.dll"; -else - package.path = package.path..";../?.lua"; - package.cpath = package.cpath..";../?.so"; -end - -local _realG = _G; - -require "util.import" - -local envloadfile = require "util.envload".envloadfile; - -local env_mt = { __index = function (t,k) return rawget(_realG, k) or print("WARNING: Attempt to access nil global '"..tostring(k).."'"); end }; -function testlib_new_env(t) - return setmetatable(t or {}, env_mt); -end - -function assert_equal(a, b, message, level) - if not (a == b) then - error("\n assert_equal failed: "..tostring(a).." ~= "..tostring(b)..(message and ("\n Message: "..message) or ""), (level or 1) + 1); - elseif verbosity >= 4 then - print("assert_equal succeeded: "..tostring(a).." == "..tostring(b)); - end -end - -function assert_table(a, message, level) - assert_equal(type(a), "table", message, (level or 1) + 1); -end -function assert_function(a, message, level) - assert_equal(type(a), "function", message, (level or 1) + 1); -end -function assert_string(a, message, level) - assert_equal(type(a), "string", message, (level or 1) + 1); -end -function assert_boolean(a, message) - assert_equal(type(a), "boolean", message); -end -function assert_is(a, message) - assert_equal(not not a, true, message); -end -function assert_is_not(a, message) - assert_equal(not not a, false, message); -end - - -function dosingletest(testname, fname) - local tests = setmetatable({}, { __index = _realG }); - tests.__unit = testname; - tests.__test = fname; - local chunk, err = envloadfile(testname, tests); - if not chunk then - print("WARNING: ", "Failed to load tests for "..testname, err); - return; - end - - local success, err = pcall(chunk); - if not success then - print("WARNING: ", "Failed to initialise tests for "..testname, err); - return; - end - - if type(tests[fname]) ~= "function" then - error(testname.." has no test '"..fname.."'", 0); - end - - - local line_hook, line_info = new_line_coverage_monitor(testname); - debug.sethook(line_hook, "l") - local success, ret = pcall(tests[fname]); - debug.sethook(); - if not success then - tests_passed = false; - print("TEST FAILED! Unit: ["..testname.."] Function: ["..fname.."]"); - print(" Location: "..ret:gsub(":%s*\n", "\n")); - line_info(fname, false, report_file); - elseif verbosity >= 2 then - print("TEST SUCCEEDED: ", testname, fname); - print(string.format("TEST COVERED %d/%d lines", line_info(fname, true, report_file))); - else - line_info(name, success, report_file); - end -end - -function dotest(unitname) - local _fakeG = setmetatable({}, {__index = _realG}); - _fakeG._G = _fakeG; - local tests = setmetatable({}, { __index = _fakeG }); - tests.__unit = unitname; - local chunk, err = envloadfile("test_"..unitname:gsub("%.", "_")..".lua", tests); - if not chunk then - print("WARNING: ", "Failed to load tests for "..unitname, err); - return; - end - - local success, err = pcall(chunk); - if not success then - print("WARNING: ", "Failed to initialise tests for "..unitname, err); - return; - end - if tests.env then setmetatable(tests.env, { __index = _realG }); end - local unit = setmetatable({}, { __index = setmetatable({ _G = tests.env or _fakeG }, { __index = tests.env or _fakeG }) }); - local fn = "../"..unitname:gsub("%.", "/")..".lua"; - local chunk, err = envloadfile(fn, unit); - if not chunk then - print("WARNING: ", "Failed to load module: "..unitname, err); - return; - end - - local oldmodule, old_M = _fakeG.module, _fakeG._M; - _fakeG.module = function () - setmetatable(unit, nil); - unit._M = unit; - end - local success, ret = pcall(chunk); - _fakeG.module, _fakeG._M = oldmodule, old_M; - if not success then - print("WARNING: ", "Failed to initialise module: "..unitname, ret); - return; - end - - if type(ret) == "table" then - for k,v in pairs(ret) do - unit[k] = v; - end - end - - for name, f in pairs(unit) do - local test = rawget(tests, name); - if type(f) ~= "function" then - if verbosity >= 3 then - print("INFO: ", "Skipping "..unitname.."."..name.." because it is not a function"); - end - elseif type(test) ~= "function" then - if verbosity >= 1 then - print("WARNING: ", unitname.."."..name.." has no test!"); - end - else - if verbosity >= 4 then - print("INFO: ", "Testing "..unitname.."."..name); - end - local line_hook, line_info = new_line_coverage_monitor(fn); - debug.sethook(line_hook, "l") - local success, ret = pcall(test, f, unit); - debug.sethook(); - if not success then - tests_passed = false; - print("TEST FAILED! Unit: ["..unitname.."] Function: ["..name.."]"); - print(" Location: "..ret:gsub(":%s*\n", "\n")); - line_info(name, false, report_file); - elseif verbosity >= 2 then - print("TEST SUCCEEDED: ", unitname, name); - print(string.format("TEST COVERED %d/%d lines", line_info(name, true, report_file))); - else - line_info(name, success, report_file); - end - end - end -end - -function runtest(f, msg) - if not f then print("SUBTEST NOT FOUND: "..(msg or "(no description)")); return; end - local success, ret = pcall(f); - if success and verbosity >= 2 then - print("SUBTEST PASSED: "..(msg or "(no description)")); - elseif (not success) and verbosity >= 0 then - tests_passed = false; - print("SUBTEST FAILED: "..(msg or "(no description)")); - error(ret, 0); - end -end - -function new_line_coverage_monitor(file) - local lines_hit, funcs_hit = {}, {}; - local total_lines, covered_lines = 0, 0; - - for line in io.lines(file) do - total_lines = total_lines + 1; - end - - return function (event, line) -- Line hook - if not lines_hit[line] then - local info = debug.getinfo(2, "fSL") - if not info.source:find(file) then return; end - if not funcs_hit[info.func] and info.activelines then - funcs_hit[info.func] = true; - for line in pairs(info.activelines) do - lines_hit[line] = false; -- Marks it as hittable, but not hit yet - end - end - if lines_hit[line] == false then - --print("New line hit: "..line.." in "..debug.getinfo(2, "S").source); - lines_hit[line] = true; - covered_lines = covered_lines + 1; - end - end - end, - function (test_name, success) -- Get info - local fn = file:gsub("^%W*", ""); - local total_active_lines = 0; - local coverage_file = io.open("reports/coverage_"..fn:gsub("%W+", "_")..".report", "a+"); - for line, active in pairs(lines_hit) do - if active ~= nil then total_active_lines = total_active_lines + 1; end - if coverage_file then - if active == false then coverage_file:write(fn, "|", line, "|", name or "", "|miss\n"); - else coverage_file:write(fn, "|", line, "|", name or "", "|", tostring(success), "\n"); end - end - end - if coverage_file then coverage_file:close(); end - return covered_lines, total_active_lines, lines_hit; - end -end - -run_all_tests() - -os.exit(tests_passed and 0 or 1); diff --git a/tests/test_core_configmanager.lua b/tests/test_core_configmanager.lua deleted file mode 100644 index 5bd469c6..00000000 --- a/tests/test_core_configmanager.lua +++ /dev/null @@ -1,33 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - - -function get(get, config) - config.set("example.com", "testkey", 123); - assert_equal(get("example.com", "testkey"), 123, "Retrieving a set key"); - - config.set("*", "testkey1", 321); - assert_equal(get("*", "testkey1"), 321, "Retrieving a set global key"); - assert_equal(get("example.com", "testkey1"), 321, "Retrieving a set key of undefined host, of which only a globally set one exists"); - - config.set("example.com", ""); -- Creates example.com host in config - assert_equal(get("example.com", "testkey1"), 321, "Retrieving a set key, of which only a globally set one exists"); - - assert_equal(get(), nil, "No parameters to get()"); - assert_equal(get("undefined host"), nil, "Getting for undefined host"); - assert_equal(get("undefined host", "undefined key"), nil, "Getting for undefined host & key"); -end - -function set(set, u) - assert_equal(set("*"), false, "Set with no key"); - - assert_equal(set("*", "set_test", "testkey"), true, "Setting a nil global value"); - assert_equal(set("*", "set_test", "testkey", 123), true, "Setting a global value"); -end - diff --git a/tests/test_core_s2smanager.lua b/tests/test_core_s2smanager.lua deleted file mode 100644 index d2dbf830..00000000 --- a/tests/test_core_s2smanager.lua +++ /dev/null @@ -1,50 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -env = { - prosody = { events = require "util.events".new() }; -}; - -function compare_srv_priorities(csp) - local r1 = { priority = 10, weight = 0 } - local r2 = { priority = 100, weight = 0 } - local r3 = { priority = 1000, weight = 2 } - local r4 = { priority = 1000, weight = 2 } - local r5 = { priority = 1000, weight = 5 } - - assert_equal(csp(r1, r1), false); - assert_equal(csp(r1, r2), true); - assert_equal(csp(r1, r3), true); - assert_equal(csp(r1, r4), true); - assert_equal(csp(r1, r5), true); - - assert_equal(csp(r2, r1), false); - assert_equal(csp(r2, r2), false); - assert_equal(csp(r2, r3), true); - assert_equal(csp(r2, r4), true); - assert_equal(csp(r2, r5), true); - - assert_equal(csp(r3, r1), false); - assert_equal(csp(r3, r2), false); - assert_equal(csp(r3, r3), false); - assert_equal(csp(r3, r4), false); - assert_equal(csp(r3, r5), false); - - assert_equal(csp(r4, r1), false); - assert_equal(csp(r4, r2), false); - assert_equal(csp(r4, r3), false); - assert_equal(csp(r4, r4), false); - assert_equal(csp(r4, r5), false); - - assert_equal(csp(r5, r1), false); - assert_equal(csp(r5, r2), false); - assert_equal(csp(r5, r3), true); - assert_equal(csp(r5, r4), true); - assert_equal(csp(r5, r5), false); - -end diff --git a/tests/test_core_stanza_router.lua b/tests/test_core_stanza_router.lua deleted file mode 100644 index ca6b78fc..00000000 --- a/tests/test_core_stanza_router.lua +++ /dev/null @@ -1,232 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -_G.prosody = { full_sessions = {}; bare_sessions = {}; hosts = {}; }; - -function core_process_stanza(core_process_stanza, u) - local stanza = require "util.stanza"; - local s2sout_session = { to_host = "remotehost", from_host = "localhost", type = "s2sout" } - local s2sin_session = { from_host = "remotehost", to_host = "localhost", type = "s2sin", hosts = { ["remotehost"] = { authed = true } } } - local local_host_session = { host = "localhost", type = "local", s2sout = { ["remotehost"] = s2sout_session } } - local local_user_session = { username = "user", host = "localhost", resource = "resource", full_jid = "user@localhost/resource", type = "c2s" } - - _G.prosody.hosts["localhost"] = local_host_session; - _G.prosody.full_sessions["user@localhost/resource"] = local_user_session; - _G.prosody.bare_sessions["user@localhost"] = { sessions = { resource = local_user_session } }; - - -- Test message routing - local function test_message_full_jid() - local env = testlib_new_env(); - local msg = stanza.stanza("message", { to = "user@localhost/resource", type = "chat" }):tag("body"):text("Hello world"); - - local target_routed; - - function env.core_post_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of routed stanza is not correct"); - assert_equal(p_stanza, msg, "routed stanza is not correct one: "..p_stanza:pretty_print()); - target_routed = true; - end - - env.hosts = hosts; - env.prosody = { hosts = hosts }; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_routed, true, "stanza was not routed successfully"); - end - - local function test_message_bare_jid() - local env = testlib_new_env(); - local msg = stanza.stanza("message", { to = "user@localhost", type = "chat" }):tag("body"):text("Hello world"); - - local target_routed; - - function env.core_post_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of routed stanza is not correct"); - assert_equal(p_stanza, msg, "routed stanza is not correct one: "..p_stanza:pretty_print()); - target_routed = true; - end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_routed, true, "stanza was not routed successfully"); - end - - local function test_message_no_to() - local env = testlib_new_env(); - local msg = stanza.stanza("message", { type = "chat" }):tag("body"):text("Hello world"); - - local target_handled; - - function env.core_post_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); - assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); - target_handled = true; - end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_handled, true, "stanza was not handled successfully"); - end - - local function test_message_to_remote_bare() - local env = testlib_new_env(); - local msg = stanza.stanza("message", { to = "user@remotehost", type = "chat" }):tag("body"):text("Hello world"); - - local target_routed; - - function env.core_route_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); - assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); - target_routed = true; - end - - function env.core_post_stanza(...) env.core_route_stanza(...); end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_routed, true, "stanza was not routed successfully"); - end - - local function test_message_to_remote_server() - local env = testlib_new_env(); - local msg = stanza.stanza("message", { to = "remotehost", type = "chat" }):tag("body"):text("Hello world"); - - local target_routed; - - function env.core_route_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); - assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); - target_routed = true; - end - - function env.core_post_stanza(...) - env.core_route_stanza(...); - end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_routed, true, "stanza was not routed successfully"); - end - - --IQ tests - - - local function test_iq_to_remote_server() - local env = testlib_new_env(); - local msg = stanza.stanza("iq", { to = "remotehost", type = "get", id = "id" }):tag("body"):text("Hello world"); - - local target_routed; - - function env.core_route_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); - assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); - target_routed = true; - end - - function env.core_post_stanza(...) - env.core_route_stanza(...); - end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_routed, true, "stanza was not routed successfully"); - end - - local function test_iq_error_to_local_user() - local env = testlib_new_env(); - local msg = stanza.stanza("iq", { to = "user@localhost/resource", from = "user@remotehost", type = "error", id = "id" }):tag("error", { type = 'cancel' }):tag("item-not-found", { xmlns='urn:ietf:params:xml:ns:xmpp-stanzas' }); - - local target_routed; - - function env.core_route_stanza(p_origin, p_stanza) - assert_equal(p_origin, s2sin_session, "origin of handled stanza is not correct"); - assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); - target_routed = true; - end - - function env.core_post_stanza(...) - env.core_route_stanza(...); - end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(s2sin_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_routed, true, "stanza was not routed successfully"); - end - - local function test_iq_to_local_bare() - local env = testlib_new_env(); - local msg = stanza.stanza("iq", { to = "user@localhost", from = "user@localhost", type = "get", id = "id" }):tag("ping", { xmlns = "urn:xmpp:ping:0" }); - - local target_handled; - - function env.core_post_stanza(p_origin, p_stanza) - assert_equal(p_origin, local_user_session, "origin of handled stanza is not correct"); - assert_equal(p_stanza, msg, "handled stanza is not correct one: "..p_stanza:pretty_print()); - target_handled = true; - end - - env.hosts = hosts; - setfenv(core_process_stanza, env); - assert_equal(core_process_stanza(local_user_session, msg), nil, "core_process_stanza returned incorrect value"); - assert_equal(target_handled, true, "stanza was not handled successfully"); - end - - runtest(test_message_full_jid, "Messages with full JID destinations get routed"); - runtest(test_message_bare_jid, "Messages with bare JID destinations get routed"); - runtest(test_message_no_to, "Messages with no destination are handled by the server"); - runtest(test_message_to_remote_bare, "Messages to a remote user are routed by the server"); - runtest(test_message_to_remote_server, "Messages to a remote server's JID are routed"); - - runtest(test_iq_to_remote_server, "iq to a remote server's JID are routed"); - runtest(test_iq_to_local_bare, "iq from a local user to a local user's bare JID are handled"); - runtest(test_iq_error_to_local_user, "iq type=error to a local user's JID are routed"); -end - -function core_route_stanza(core_route_stanza) - local stanza = require "util.stanza"; - local s2sout_session = { to_host = "remotehost", from_host = "localhost", type = "s2sout" } - local s2sin_session = { from_host = "remotehost", to_host = "localhost", type = "s2sin", hosts = { ["remotehost"] = { authed = true } } } - local local_host_session = { host = "localhost", type = "local", s2sout = { ["remotehost"] = s2sout_session }, sessions = {} } - local local_user_session = { username = "user", host = "localhost", resource = "resource", full_jid = "user@localhost/resource", type = "c2s" } - local hosts = { - ["localhost"] = local_host_session; - } - - local function test_iq_result_to_offline_user() - local env = testlib_new_env(); - local msg = stanza.stanza("iq", { to = "user@localhost/foo", from = "user@localhost", type = "result" }):tag("ping", { xmlns = "urn:xmpp:ping:0" }); - local msg2 = stanza.stanza("iq", { to = "user@localhost/foo", from = "user@localhost", type = "error" }):tag("ping", { xmlns = "urn:xmpp:ping:0" }); - --package.loaded["core.usermanager"] = { user_exists = function (user, host) print("RAR!") return true or user == "user" and host == "localhost" and true; end }; - local target_handled, target_replied; - - function env.core_post_stanza(p_origin, p_stanza) - target_handled = true; - end - - function local_user_session.send(data) - --print("Replying with: ", tostring(data)); - --print(debug.traceback()) - target_replied = true; - end - - env.hosts = hosts; - setfenv(core_route_stanza, env); - assert_equal(core_route_stanza(local_user_session, msg), nil, "core_route_stanza returned incorrect value"); - assert_equal(target_handled, nil, "stanza was handled and not dropped"); - assert_equal(target_replied, nil, "stanza was replied to and not dropped"); - package.loaded["core.usermanager"] = nil; - end - - --runtest(test_iq_result_to_offline_user, "iq type=result|error to an offline user are not replied to"); -end diff --git a/tests/test_net_http_parser.lua b/tests/test_net_http_parser.lua deleted file mode 100644 index 1157b5ac..00000000 --- a/tests/test_net_http_parser.lua +++ /dev/null @@ -1,47 +0,0 @@ -local httpstreams = { [[ -GET / HTTP/1.1 -Host: example.com - -]], [[ -HTTP/1.1 200 OK -Content-Length: 0 - -]], [[ -HTTP/1.1 200 OK -Content-Length: 7 - -Hello -HTTP/1.1 200 OK -Transfer-Encoding: chunked - -1 -H -1 -e -2 -ll -1 -o -0 - - -]] -} - -function new(new) - - for _, stream in ipairs(httpstreams) do - local success; - local function success_cb(packet) - success = true; - end - stream = stream:gsub("\n", "\r\n"); - local parser = new(success_cb, error, stream:sub(1,4) == "HTTP" and "client" or "server") - for chunk in stream:gmatch("..?.?") do - parser:feed(chunk); - end - - assert_is(success); - end - -end diff --git a/tests/test_sasl.lua b/tests/test_sasl.lua deleted file mode 100644 index dd63c5a0..00000000 --- a/tests/test_sasl.lua +++ /dev/null @@ -1,38 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local gmatch = string.gmatch; -local t_concat, t_insert = table.concat, table.insert; -local to_byte, to_char = string.byte, string.char; - -local function _latin1toutf8(str) - if not str then return str; end - local p = {}; - for ch in gmatch(str, ".") do - ch = to_byte(ch); - if (ch < 0x80) then - t_insert(p, to_char(ch)); - elseif (ch < 0xC0) then - t_insert(p, to_char(0xC2, ch)); - else - t_insert(p, to_char(0xC3, ch - 64)); - end - end - return t_concat(p); -end - -function latin1toutf8() - local function assert_utf8(latin, utf8) - assert_equal(_latin1toutf8(latin), utf8, "Incorrect UTF8 from Latin1: "..tostring(latin)); - end - - assert_utf8("", "") - assert_utf8("test", "test") - assert_utf8(nil, nil) - assert_utf8("foobar.r\229kat.se", "foobar.r\195\165kat.se") -end diff --git a/tests/test_utf8.lua b/tests/test_utf8.lua deleted file mode 100644 index 48859960..00000000 --- a/tests/test_utf8.lua +++ /dev/null @@ -1,18 +0,0 @@ -package.cpath = "../?.so" -package.path = "../?.lua"; - -function valid() - local encodings = require "util.encodings"; - local utf8 = assert(encodings.utf8, "no encodings.utf8 module"); - - for line in io.lines("utf8_sequences.txt") do - local data = line:match(":%s*([^#]+)"):gsub("%s+", ""):gsub("..", function (c) return string.char(tonumber(c, 16)); end) - local expect = line:match("(%S+):"); - if expect ~= "pass" and expect ~= "fail" then - error("unknown expectation: "..line:match("^[^:]+")); - end - local valid = utf8.valid(data); - assert_equal(valid, utf8.valid(data.." ")); - assert_equal(valid, expect == "pass", line); - end -end diff --git a/tests/test_util_cache.lua b/tests/test_util_cache.lua deleted file mode 100644 index 4240c433..00000000 --- a/tests/test_util_cache.lua +++ /dev/null @@ -1,309 +0,0 @@ -function new(new) - local c = new(5); - - local function expect_kv(key, value, actual_key, actual_value) - assert_equal(key, actual_key, "key incorrect"); - assert_equal(value, actual_value, "value incorrect"); - end - - expect_kv(nil, nil, c:head()); - expect_kv(nil, nil, c:tail()); - - assert_equal(c:count(), 0); - - c:set("one", 1) - assert_equal(c:count(), 1); - expect_kv("one", 1, c:head()); - expect_kv("one", 1, c:tail()); - - c:set("two", 2) - expect_kv("two", 2, c:head()); - expect_kv("one", 1, c:tail()); - - c:set("three", 3) - expect_kv("three", 3, c:head()); - expect_kv("one", 1, c:tail()); - - c:set("four", 4) - c:set("five", 5); - assert_equal(c:count(), 5); - expect_kv("five", 5, c:head()); - expect_kv("one", 1, c:tail()); - - c:set("foo", nil); - assert_equal(c:count(), 5); - expect_kv("five", 5, c:head()); - expect_kv("one", 1, c:tail()); - - assert_equal(c:get("one"), 1); - expect_kv("five", 5, c:head()); - expect_kv("one", 1, c:tail()); - - assert_equal(c:get("two"), 2); - assert_equal(c:get("three"), 3); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), 5); - - assert_equal(c:get("foo"), nil); - assert_equal(c:get("bar"), nil); - - c:set("six", 6); - assert_equal(c:count(), 5); - expect_kv("six", 6, c:head()); - expect_kv("two", 2, c:tail()); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), 2); - assert_equal(c:get("three"), 3); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), 5); - assert_equal(c:get("six"), 6); - - c:set("three", nil); - assert_equal(c:count(), 4); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), 2); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), 5); - assert_equal(c:get("six"), 6); - - c:set("seven", 7); - assert_equal(c:count(), 5); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), 2); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), 5); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - - c:set("eight", 8); - assert_equal(c:count(), 5); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), nil); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), 5); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - assert_equal(c:get("eight"), 8); - - c:set("four", 4); - assert_equal(c:count(), 5); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), nil); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), 5); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - assert_equal(c:get("eight"), 8); - - c:set("nine", 9); - assert_equal(c:count(), 5); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), nil); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), 4); - assert_equal(c:get("five"), nil); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - assert_equal(c:get("eight"), 8); - assert_equal(c:get("nine"), 9); - - do - local keys = { "nine", "four", "eight", "seven", "six" }; - local values = { 9, 4, 8, 7, 6 }; - local i = 0; - for k, v in c:items() do - i = i + 1; - assert_equal(k, keys[i]); - assert_equal(v, values[i]); - end - assert_equal(i, 5); - - c:set("four", "2+2"); - assert_equal(c:count(), 5); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), nil); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), "2+2"); - assert_equal(c:get("five"), nil); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - assert_equal(c:get("eight"), 8); - assert_equal(c:get("nine"), 9); - end - - do - local keys = { "four", "nine", "eight", "seven", "six" }; - local values = { "2+2", 9, 8, 7, 6 }; - local i = 0; - for k, v in c:items() do - i = i + 1; - assert_equal(k, keys[i]); - assert_equal(v, values[i]); - end - assert_equal(i, 5); - - c:set("foo", nil); - assert_equal(c:count(), 5); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), nil); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), "2+2"); - assert_equal(c:get("five"), nil); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - assert_equal(c:get("eight"), 8); - assert_equal(c:get("nine"), 9); - end - - do - local keys = { "four", "nine", "eight", "seven", "six" }; - local values = { "2+2", 9, 8, 7, 6 }; - local i = 0; - for k, v in c:items() do - i = i + 1; - assert_equal(k, keys[i]); - assert_equal(v, values[i]); - end - assert_equal(i, 5); - - c:set("four", nil); - - assert_equal(c:get("one"), nil); - assert_equal(c:get("two"), nil); - assert_equal(c:get("three"), nil); - assert_equal(c:get("four"), nil); - assert_equal(c:get("five"), nil); - assert_equal(c:get("six"), 6); - assert_equal(c:get("seven"), 7); - assert_equal(c:get("eight"), 8); - assert_equal(c:get("nine"), 9); - end - - do - local keys = { "nine", "eight", "seven", "six" }; - local values = { 9, 8, 7, 6 }; - local i = 0; - for k, v in c:items() do - i = i + 1; - assert_equal(k, keys[i]); - assert_equal(v, values[i]); - end - assert_equal(i, 4); - end - - do - local evicted_key, evicted_value; - local c2 = new(3, function (_key, _value) - evicted_key, evicted_value = _key, _value; - end); - local function set(k, v, should_evict_key, should_evict_value) - evicted_key, evicted_value = nil, nil; - c2:set(k, v); - assert_equal(evicted_key, should_evict_key); - assert_equal(evicted_value, should_evict_value); - end - set("a", 1) - set("a", 1) - set("a", 1) - set("a", 1) - set("a", 1) - - set("b", 2) - set("c", 3) - set("b", 2) - set("d", 4, "a", 1) - set("e", 5, "c", 3) - end - - do - local evicted_key, evicted_value; - local c3 = new(1, function (_key, _value) - evicted_key, evicted_value = _key, _value; - if _key == "a" then - -- Sanity check for what we're evicting - assert_equal(_key, "a"); - assert_equal(_value, 1); - -- We're going to block eviction of this key/value, so set to nil... - evicted_key, evicted_value = nil, nil; - -- Returning false to block eviction - return false - end - end); - local function set(k, v, should_evict_key, should_evict_value) - evicted_key, evicted_value = nil, nil; - local ret = c3:set(k, v); - assert_equal(evicted_key, should_evict_key); - assert_equal(evicted_value, should_evict_value); - return ret; - end - set("a", 1) - set("a", 1) - set("a", 1) - set("a", 1) - set("a", 1) - - -- Our on_evict prevents "a" from being evicted, causing this to fail... - assert_equal(set("b", 2), false, "Failed to prevent eviction, or signal result"); - - expect_kv("a", 1, c3:head()); - expect_kv("a", 1, c3:tail()); - - -- Check the final state is what we expect - assert_equal(c3:get("a"), 1); - assert_equal(c3:get("b"), nil); - assert_equal(c3:count(), 1); - end - - - local c4 = new(3, false); - - assert_equal(c4:set("a", 1), true); - assert_equal(c4:set("a", 1), true); - assert_equal(c4:set("a", 1), true); - assert_equal(c4:set("a", 1), true); - assert_equal(c4:set("b", 2), true); - assert_equal(c4:set("c", 3), true); - assert_equal(c4:set("d", 4), false); - assert_equal(c4:set("d", 4), false); - assert_equal(c4:set("d", 4), false); - - expect_kv("c", 3, c4:head()); - expect_kv("a", 1, c4:tail()); - - local c5 = new(3, function (k, v) - if k == "a" then - return nil; - elseif k == "b" then - return true; - end - return false; - end); - - assert_equal(c5:set("a", 1), true); - assert_equal(c5:set("a", 1), true); - assert_equal(c5:set("a", 1), true); - assert_equal(c5:set("a", 1), true); - assert_equal(c5:set("b", 2), true); - assert_equal(c5:set("c", 3), true); - assert_equal(c5:set("d", 4), true); -- "a" evicted (cb returned nil) - assert_equal(c5:set("d", 4), true); -- nop - assert_equal(c5:set("d", 4), true); -- nop - assert_equal(c5:set("e", 5), true); -- "b" evicted (cb returned true) - assert_equal(c5:set("f", 6), false); -- "c" won't evict (cb returned false) - - expect_kv("e", 5, c5:head()); - expect_kv("c", 3, c5:tail()); - -end diff --git a/tests/test_util_http.lua b/tests/test_util_http.lua deleted file mode 100644 index d9cc2779..00000000 --- a/tests/test_util_http.lua +++ /dev/null @@ -1,41 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -function urlencode(urlencode) - assert_equal(urlencode("helloworld123"), "helloworld123", "Normal characters not escaped"); - assert_equal(urlencode("hello world"), "hello%20world", "Spaces escaped"); - assert_equal(urlencode("This & that = something"), "This%20%26%20that%20%3d%20something", "Important URL chars escaped"); -end - -function urldecode(urldecode) - assert_equal("helloworld123", urldecode("helloworld123"), "Normal characters not escaped"); - assert_equal("hello world", urldecode("hello%20world"), "Spaces escaped"); - assert_equal("This & that = something", urldecode("This%20%26%20that%20%3d%20something"), "Important URL chars escaped"); - assert_equal("This & that = something", urldecode("This%20%26%20that%20%3D%20something"), "Important URL chars escaped"); -end - -function formencode(formencode) - assert_equal(formencode({ { name = "one", value = "1"}, { name = "two", value = "2" } }), "one=1&two=2", "Form encoded"); - assert_equal(formencode({ { name = "one two", value = "1"}, { name = "two one&", value = "2" } }), "one+two=1&two+one%26=2", "Form encoded"); -end - -function formdecode(formdecode) - do - local t = formdecode("one=1&two=2"); - assert_table(t[1]); - assert_equal(t[1].name, "one"); assert_equal(t[1].value, "1"); - assert_table(t[2]); - assert_equal(t[2].name, "two"); assert_equal(t[2].value, "2"); - end - - do - local t = formdecode("one+two=1&two+one%26=2"); - assert_equal(t[1].name, "one two"); assert_equal(t[1].value, "1"); - assert_equal(t[2].name, "two one&"); assert_equal(t[2].value, "2"); - end -end diff --git a/tests/test_util_ip.lua b/tests/test_util_ip.lua deleted file mode 100644 index 0ded1123..00000000 --- a/tests/test_util_ip.lua +++ /dev/null @@ -1,89 +0,0 @@ - -function match(match, _M) - local _ = _M.new_ip; - local ip = _"10.20.30.40"; - assert_equal(match(ip, _"10.0.0.0", 8), true); - assert_equal(match(ip, _"10.0.0.0", 16), false); - assert_equal(match(ip, _"10.0.0.0", 24), false); - assert_equal(match(ip, _"10.0.0.0", 32), false); - - assert_equal(match(ip, _"10.20.0.0", 8), true); - assert_equal(match(ip, _"10.20.0.0", 16), true); - assert_equal(match(ip, _"10.20.0.0", 24), false); - assert_equal(match(ip, _"10.20.0.0", 32), false); - - assert_equal(match(ip, _"0.0.0.0", 32), false); - assert_equal(match(ip, _"0.0.0.0", 0), true); - assert_equal(match(ip, _"0.0.0.0"), false); - - assert_equal(match(ip, _"10.0.0.0", 255), false, "excessive number of bits"); - assert_equal(match(ip, _"10.0.0.0", -8), true, "negative number of bits"); - assert_equal(match(ip, _"10.0.0.0", -32), true, "negative number of bits"); - assert_equal(match(ip, _"10.0.0.0", 0), true, "zero bits"); - assert_equal(match(ip, _"10.0.0.0"), false, "no specified number of bits (differing ip)"); - assert_equal(match(ip, _"10.20.30.40"), true, "no specified number of bits (same ip)"); - - assert_equal(match(_"127.0.0.1", _"127.0.0.1"), true, "simple ip"); - - assert_equal(match(_"8.8.8.8", _"8.8.0.0", 16), true); - assert_equal(match(_"8.8.4.4", _"8.8.0.0", 16), true); -end - -function parse_cidr(parse_cidr, _M) - local new_ip = _M.new_ip; - - assert_equal(new_ip"0.0.0.0", new_ip"0.0.0.0") - - local function assert_cidr(cidr, ip, bits) - local parsed_ip, parsed_bits = parse_cidr(cidr); - assert_equal(new_ip(ip), parsed_ip, cidr.." parsed ip is "..ip); - assert_equal(bits, parsed_bits, cidr.." parsed bits is "..tostring(bits)); - end - assert_cidr("0.0.0.0", "0.0.0.0", nil); - assert_cidr("127.0.0.1", "127.0.0.1", nil); - assert_cidr("127.0.0.1/0", "127.0.0.1", 0); - assert_cidr("127.0.0.1/8", "127.0.0.1", 8); - assert_cidr("127.0.0.1/32", "127.0.0.1", 32); - assert_cidr("127.0.0.1/256", "127.0.0.1", 256); - assert_cidr("::/48", "::", 48); -end - -function new_ip(new_ip) - local v4, v6 = "IPv4", "IPv6"; - local function assert_proto(s, proto) - local ip = new_ip(s); - if proto then - assert_equal(ip and ip.proto, proto, "protocol is correct for "..("%q"):format(s)); - else - assert_equal(ip, nil, "address is invalid"); - end - end - assert_proto("127.0.0.1", v4); - assert_proto("::1", v6); - assert_proto("", nil); - assert_proto("abc", nil); - assert_proto(" ", nil); -end - -function commonPrefixLength(cpl, _M) - local new_ip = _M.new_ip; - local function assert_cpl6(a, b, len, v4) - local ipa, ipb = new_ip(a), new_ip(b); - if v4 then len = len+96; end - assert_equal(cpl(ipa, ipb), len, "common prefix length of "..a.." and "..b.." is "..len); - assert_equal(cpl(ipb, ipa), len, "common prefix length of "..b.." and "..a.." is "..len); - end - local function assert_cpl4(a, b, len) - return assert_cpl6(a, b, len, "IPv4"); - end - assert_cpl4("0.0.0.0", "0.0.0.0", 32); - assert_cpl4("255.255.255.255", "0.0.0.0", 0); - assert_cpl4("255.255.255.255", "255.255.0.0", 16); - assert_cpl4("255.255.255.255", "255.255.255.255", 32); - assert_cpl4("255.255.255.255", "255.255.255.255", 32); - - assert_cpl6("::1", "::1", 128); - assert_cpl6("abcd::1", "abcd::1", 128); - assert_cpl6("abcd::abcd", "abcd::", 112); - assert_cpl6("abcd::abcd", "abcd::abcd:abcd", 96); -end diff --git a/tests/test_util_jid.lua b/tests/test_util_jid.lua deleted file mode 100644 index 0ac5827e..00000000 --- a/tests/test_util_jid.lua +++ /dev/null @@ -1,143 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -function join(join) - assert_equal(join("a", "b", "c"), "a@b/c", "builds full JID"); - assert_equal(join("a", "b", nil), "a@b", "builds bare JID"); - assert_equal(join(nil, "b", "c"), "b/c", "builds full host JID"); - assert_equal(join(nil, "b", nil), "b", "builds bare host JID"); - assert_equal(join(nil, nil, nil), nil, "invalid JID is nil"); - assert_equal(join("a", nil, nil), nil, "invalid JID is nil"); - assert_equal(join(nil, nil, "c"), nil, "invalid JID is nil"); - assert_equal(join("a", nil, "c"), nil, "invalid JID is nil"); -end - - -function split(split) - local function test(input_jid, expected_node, expected_server, expected_resource) - local rnode, rserver, rresource = split(input_jid); - assert_equal(expected_node, rnode, "split("..tostring(input_jid)..") failed"); - assert_equal(expected_server, rserver, "split("..tostring(input_jid)..") failed"); - assert_equal(expected_resource, rresource, "split("..tostring(input_jid)..") failed"); - end - - -- Valid JIDs - test("node@server", "node", "server", nil ); - test("node@server/resource", "node", "server", "resource" ); - test("server", nil, "server", nil ); - test("server/resource", nil, "server", "resource" ); - test("server/resource@foo", nil, "server", "resource@foo" ); - test("server/resource@foo/bar", nil, "server", "resource@foo/bar"); - - -- Always invalid JIDs - test(nil, nil, nil, nil); - test("node@/server", nil, nil, nil); - test("@server", nil, nil, nil); - test("@server/resource", nil, nil, nil); - test("@/resource", nil, nil, nil); -end - -function bare(bare) - assert_equal(bare("user@host"), "user@host", "bare JID remains bare"); - assert_equal(bare("host"), "host", "Host JID remains host"); - assert_equal(bare("host/resource"), "host", "Host JID with resource becomes host"); - assert_equal(bare("user@host/resource"), "user@host", "user@host JID with resource becomes user@host"); - assert_equal(bare("user@/resource"), nil, "invalid JID is nil"); - assert_equal(bare("@/resource"), nil, "invalid JID is nil"); - assert_equal(bare("@/"), nil, "invalid JID is nil"); - assert_equal(bare("/"), nil, "invalid JID is nil"); - assert_equal(bare(""), nil, "invalid JID is nil"); - assert_equal(bare("@"), nil, "invalid JID is nil"); - assert_equal(bare("user@"), nil, "invalid JID is nil"); - assert_equal(bare("user@@"), nil, "invalid JID is nil"); - assert_equal(bare("user@@host"), nil, "invalid JID is nil"); - assert_equal(bare("user@@host/resource"), nil, "invalid JID is nil"); - assert_equal(bare("user@host/"), nil, "invalid JID is nil"); -end - -function compare(compare) - assert_equal(compare("host", "host"), true, "host should match"); - assert_equal(compare("host", "other-host"), false, "host should not match"); - assert_equal(compare("other-user@host/resource", "host"), true, "host should match"); - assert_equal(compare("other-user@host", "user@host"), false, "user should not match"); - assert_equal(compare("user@host", "host"), true, "host should match"); - assert_equal(compare("user@host/resource", "host"), true, "host should match"); - assert_equal(compare("user@host/resource", "user@host"), true, "user and host should match"); - assert_equal(compare("user@other-host", "host"), false, "host should not match"); - assert_equal(compare("user@other-host", "user@host"), false, "host should not match"); -end - -function node(node) - local function test(jid, expected_node) - assert_equal(node(jid), expected_node, "Unexpected node for "..tostring(jid)); - end - - test("example.com", nil); - test("foo.example.com", nil); - test("foo.example.com/resource", nil); - test("foo.example.com/some resource", nil); - test("foo.example.com/some@resource", nil); - - test("foo@foo.example.com/some@resource", "foo"); - test("foo@example/some@resource", "foo"); - - test("foo@example/@resource", "foo"); - test("foo@example@resource", nil); - test("foo@example", "foo"); - test("foo", nil); - - test(nil, nil); -end - -function host(host) - local function test(jid, expected_host) - assert_equal(host(jid), expected_host, "Unexpected host for "..tostring(jid)); - end - - test("example.com", "example.com"); - test("foo.example.com", "foo.example.com"); - test("foo.example.com/resource", "foo.example.com"); - test("foo.example.com/some resource", "foo.example.com"); - test("foo.example.com/some@resource", "foo.example.com"); - - test("foo@foo.example.com/some@resource", "foo.example.com"); - test("foo@example/some@resource", "example"); - - test("foo@example/@resource", "example"); - test("foo@example@resource", nil); - test("foo@example", "example"); - test("foo", "foo"); - - test(nil, nil); -end - -function resource(resource) - local function test(jid, expected_resource) - assert_equal(resource(jid), expected_resource, "Unexpected resource for "..tostring(jid)); - end - - test("example.com", nil); - test("foo.example.com", nil); - test("foo.example.com/resource", "resource"); - test("foo.example.com/some resource", "some resource"); - test("foo.example.com/some@resource", "some@resource"); - - test("foo@foo.example.com/some@resource", "some@resource"); - test("foo@example/some@resource", "some@resource"); - - test("foo@example/@resource", "@resource"); - test("foo@example@resource", nil); - test("foo@example", nil); - test("foo", nil); - test("/foo", nil); - test("@x/foo", nil); - test("@/foo", nil); - - test(nil, nil); -end - diff --git a/tests/test_util_json.lua b/tests/test_util_json.lua deleted file mode 100644 index 2c1a9ce9..00000000 --- a/tests/test_util_json.lua +++ /dev/null @@ -1,21 +0,0 @@ - -function encode(encode, json) - local function test(f, j, e) - if e then - assert_equal(f(j), e); - end - assert_equal(f(j), f(json.decode(f(j)))); - end - test(encode, json.null, "null") - test(encode, {}, "{}") - test(encode, {a=1}); - test(encode, {a={1,2,3}}); - test(encode, {1}, "[1]"); -end - -function decode(decode) - local empty_array = decode("[]"); - assert_equal(type(empty_array), "table"); - assert_equal(#empty_array, 0); - assert_equal(next(empty_array), nil); -end diff --git a/tests/test_util_json.sh b/tests/test_util_json.sh deleted file mode 100755 index bbbd132b..00000000 --- a/tests/test_util_json.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash - -export LUA_PATH="../?.lua;;" -export LUA_CPATH="../?.so;;" - -#set -x - -if ! which "$RUNWITH"; then - echo "Unable to find interpreter $RUNWITH"; - exit 1; -fi - -if ! $RUNWITH -e 'assert(require"util.json")' 2>/dev/null; then - echo "Unable to find util.json"; - exit 1; -fi - -FAIL=0 - -for f in json/pass*.json; do - if ! $RUNWITH -e 'local j=require"util.json" assert(j.decode(io.read("*a"))~=nil)' <"$f" 2>/dev/null; then - echo "Failed to decode valid JSON: $f"; - FAIL=1 - fi -done - -for f in json/fail*.json; do - if ! $RUNWITH -e 'local j=require"util.json" assert(j.decode(io.read("*a"))==nil)' <"$f" 2>/dev/null; then - echo "Invalid JSON decoded without error: $f"; - FAIL=1 - fi -done - -if [ "$FAIL" == "1" ]; then - echo "JSON tests failed" - exit 1; -fi - -exit 0; diff --git a/tests/test_util_multitable.lua b/tests/test_util_multitable.lua deleted file mode 100644 index 45727bc3..00000000 --- a/tests/test_util_multitable.lua +++ /dev/null @@ -1,62 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -function new(new, multitable) - local mt = new(); - assert_table(mt, "Multitable is a table"); - assert_function(mt.add, "Multitable has method add"); - assert_function(mt.get, "Multitable has method get"); - assert_function(mt.remove, "Multitable has method remove"); - - get(mt.get, multitable); -end - -function get(get, multitable) - local function has_items(list, ...) - local should_have = {}; - if select('#', ...) > 0 then - assert_table(list, "has_items: list is table", 3); - else - assert_is_not(list and #list > 0, "No items, and no list"); - return true, "has-all"; - end - for n=1,select('#', ...) do should_have[select(n, ...)] = true; end - for _, item in ipairs(list) do - if not should_have[item] then return false, "too-many"; end - should_have[item] = nil; - end - if next(should_have) then - return false, "not-enough"; - end - return true, "has-all"; - end - local function assert_has_all(message, list, ...) - return assert_equal(select(2, has_items(list, ...)), "has-all", message or "List has all expected items, and no more", 2); - end - - local mt = multitable.new(); - - local trigger1, trigger2, trigger3 = {}, {}, {}; - local item1, item2, item3 = {}, {}, {}; - - assert_has_all("Has no items with trigger1", mt:get(trigger1)); - - - mt:add(1, 2, 3, item1); - - assert_has_all("Has item1 for 1, 2, 3", mt:get(1, 2, 3), item1); - --- Doesn't support nil ---[[ mt:add(nil, item1); - mt:add(nil, item2); - mt:add(nil, item3); - - assert_has_all("Has all items with (nil)", mt:get(nil), item1, item2, item3); -]] -end diff --git a/tests/test_util_queue.lua b/tests/test_util_queue.lua deleted file mode 100644 index e215ba33..00000000 --- a/tests/test_util_queue.lua +++ /dev/null @@ -1,74 +0,0 @@ - -function new(new) - do - local q = new(10); - - assert_equal(q.size, 10); - assert_equal(q:count(), 0); - - assert_is(q:push("one")); - assert_is(q:push("two")); - assert_is(q:push("three")); - - for i = 4, 10 do - assert_is(q:push("hello")); - assert_equal(q:count(), i, "count is not "..i.."("..q:count()..")"); - end - assert_equal(q:push("hello"), nil, "queue overfull!"); - assert_equal(q:push("hello"), nil, "queue overfull!"); - assert_equal(q:pop(), "one", "queue item incorrect"); - assert_equal(q:pop(), "two", "queue item incorrect"); - assert_is(q:push("hello")); - assert_is(q:push("hello")); - assert_equal(q:pop(), "three", "queue item incorrect"); - assert_is(q:push("hello")); - assert_equal(q:push("hello"), nil, "queue overfull!"); - assert_equal(q:push("hello"), nil, "queue overfull!"); - - assert_equal(q:count(), 10, "queue count incorrect"); - - for _ = 1, 10 do - assert_equal(q:pop(), "hello", "queue item incorrect"); - end - - assert_equal(q:count(), 0, "queue count incorrect"); - - assert_is(q:push(1)); - for i = 1, 1001 do - assert_equal(q:pop(), i); - assert_equal(q:count(), 0); - assert_is(q:push(i+1)); - assert_equal(q:count(), 1); - end - assert_equal(q:pop(), 1002); - assert_is(q:push(1)); - for i = 1, 1000 do - assert_equal(q:pop(), i); - assert_is(q:push(i+1)); - end - assert_equal(q:pop(), 1001); - assert_equal(q:count(), 0); - end - - do - -- Test queues that purge old items when pushing to a full queue - local q = new(10, true); - - for i = 1, 10 do - q:push(i); - end - - assert_equal(q:count(), 10); - - assert_is(q:push(11)); - assert_equal(q:count(), 10); - assert_equal(q:pop(), 2); -- First item should have been purged - - for i = 12, 32 do - assert_is(q:push(i)); - end - - assert_equal(q:count(), 10); - assert_equal(q:pop(), 23); - end -end diff --git a/tests/test_util_random.lua b/tests/test_util_random.lua deleted file mode 100644 index 79572ef8..00000000 --- a/tests/test_util_random.lua +++ /dev/null @@ -1,10 +0,0 @@ --- Makes no attempt at testing how random the bytes are, --- just that it returns the number of bytes requested - -function bytes(bytes) - assert_is(bytes(16)); - - for i = 1, 255 do - assert_equal(i, #bytes(i)); - end -end diff --git a/tests/test_util_rfc6724.lua b/tests/test_util_rfc6724.lua deleted file mode 100644 index bb73e921..00000000 --- a/tests/test_util_rfc6724.lua +++ /dev/null @@ -1,97 +0,0 @@ --- Prosody IM --- Copyright (C) 2011-2013 Florian Zeitz --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -function source(source) - local new_ip = require"util.ip".new_ip; - assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), - {new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::1", "IPv6")}).addr, - "2001:db8:3::1", - "prefer appropriate scope"); - assert_equal(source(new_ip("ff05::1", "IPv6"), - {new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::1", "IPv6")}).addr, - "2001:db8:3::1", - "prefer appropriate scope"); - assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), - {new_ip("2001:db8:1::1", "IPv6"), new_ip("2001:db8:2::1", "IPv6")}).addr, - "2001:db8:1::1", - "prefer same address"); -- "2001:db8:1::1" should be marked "deprecated" here, we don't handle that right now - assert_equal(source(new_ip("fe80::1", "IPv6"), - {new_ip("fe80::2", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}).addr, - "fe80::2", - "prefer appropriate scope"); -- "fe80::2" should be marked "deprecated" here, we don't handle that right now - assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), - {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::2", "IPv6")}).addr, - "2001:db8:1::2", - "longest matching prefix"); ---[[ "2001:db8:1::2" should be a care-of address and "2001:db8:3::2" a home address, we can't handle this and would fail - assert_equal(source(new_ip("2001:db8:1::1", "IPv6"), - {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::2", "IPv6")}).addr, - "2001:db8:3::2", - "prefer home address"); -]] - assert_equal(source(new_ip("2002:c633:6401::1", "IPv6"), - {new_ip("2002:c633:6401::d5e3:7953:13eb:22e8", "IPv6"), new_ip("2001:db8:1::2", "IPv6")}).addr, - "2002:c633:6401::d5e3:7953:13eb:22e8", - "prefer matching label"); -- "2002:c633:6401::d5e3:7953:13eb:22e8" should be marked "temporary" here, we don't handle that right now - assert_equal(source(new_ip("2001:db8:1::d5e3:0:0:1", "IPv6"), - {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:1::d5e3:7953:13eb:22e8", "IPv6")}).addr, - "2001:db8:1::d5e3:7953:13eb:22e8", - "prefer temporary address") -- "2001:db8:1::2" should be marked "public" and "2001:db8:1::d5e3:7953:13eb:22e8" should be marked "temporary" here, we don't handle that right now -end - -function destination(dest) - local order; - local new_ip = require"util.ip".new_ip; - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("198.51.100.121", "IPv4")}, - {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::1", "IPv6"), new_ip("169.254.13.78", "IPv4")}) - assert_equal(order[1].addr, "2001:db8:1::1", "prefer matching scope"); - assert_equal(order[2].addr, "198.51.100.121", "prefer matching scope"); - - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("198.51.100.121", "IPv4")}, - {new_ip("fe80::1", "IPv6"), new_ip("198.51.100.117", "IPv4")}) - assert_equal(order[1].addr, "198.51.100.121", "prefer matching scope"); - assert_equal(order[2].addr, "2001:db8:1::1", "prefer matching scope"); - - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("10.1.2.3", "IPv4")}, - {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::1", "IPv6"), new_ip("10.1.2.4", "IPv4")}) - assert_equal(order[1].addr, "2001:db8:1::1", "prefer higher precedence"); - assert_equal(order[2].addr, "10.1.2.3", "prefer higher precedence"); - - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, - {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) - assert_equal(order[1].addr, "fe80::1", "prefer smaller scope"); - assert_equal(order[2].addr, "2001:db8:1::1", "prefer smaller scope"); - ---[[ "2001:db8:1::2" and "fe80::2" should be marked "care-of address", while "2001:db8:3::1" should be marked "home address", we can't currently handle this and would fail the test - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, - {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3::1", "IPv6"), new_ip("fe80::2", "IPv6")}) - assert_equal(order[1].addr, "2001:db8:1::1", "prefer home address"); - assert_equal(order[2].addr, "fe80::1", "prefer home address"); -]] - ---[[ "fe80::2" should be marked "deprecated", we can't currently handle this and would fail the test - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("fe80::1", "IPv6")}, - {new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) - assert_equal(order[1].addr, "2001:db8:1::1", "avoid deprecated addresses"); - assert_equal(order[2].addr, "fe80::1", "avoid deprecated addresses"); -]] - - order = dest({new_ip("2001:db8:1::1", "IPv6"), new_ip("2001:db8:3ffe::1", "IPv6")}, - {new_ip("2001:db8:1::2", "IPv6"), new_ip("2001:db8:3f44::2", "IPv6"), new_ip("fe80::2", "IPv6")}) - assert_equal(order[1].addr, "2001:db8:1::1", "longest matching prefix"); - assert_equal(order[2].addr, "2001:db8:3ffe::1", "longest matching prefix"); - - order = dest({new_ip("2002:c633:6401::1", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}, - {new_ip("2002:c633:6401::2", "IPv6"), new_ip("fe80::2", "IPv6")}) - assert_equal(order[1].addr, "2002:c633:6401::1", "prefer matching label"); - assert_equal(order[2].addr, "2001:db8:1::1", "prefer matching label"); - - order = dest({new_ip("2002:c633:6401::1", "IPv6"), new_ip("2001:db8:1::1", "IPv6")}, - {new_ip("2002:c633:6401::2", "IPv6"), new_ip("2001:db8:1::2", "IPv6"), new_ip("fe80::2", "IPv6")}) - assert_equal(order[1].addr, "2001:db8:1::1", "prefer higher precedence"); - assert_equal(order[2].addr, "2002:c633:6401::1", "prefer higher precedence"); -end diff --git a/tests/test_util_sasl_scram.lua b/tests/test_util_sasl_scram.lua deleted file mode 100644 index bc89829f..00000000 --- a/tests/test_util_sasl_scram.lua +++ /dev/null @@ -1,23 +0,0 @@ - - -local hmac_sha1 = require "util.hashes".hmac_sha1; -local function toHex(s) - return s and (s:gsub(".", function (c) return ("%02x"):format(c:byte()); end)); -end - -function Hi(Hi) - assert( toHex(Hi(hmac_sha1, "password", "salt", 1)) == "0c60c80f961f0e71f3a9b524af6012062fe037a6", - [[FAIL: toHex(Hi(hmac_sha1, "password", "salt", 1)) == "0c60c80f961f0e71f3a9b524af6012062fe037a6"]]) - assert( toHex(Hi(hmac_sha1, "password", "salt", 2)) == "ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957", - [[FAIL: toHex(Hi(hmac_sha1, "password", "salt", 2)) == "ea6c014dc72d6f8ccd1ed92ace1d41f0d8de8957"]]) - assert( toHex(Hi(hmac_sha1, "password", "salt", 64)) == "a7bc9b6efea2cbd717da72d83bfcc4e17d0b6280", - [[FAIL: toHex(Hi(hmac_sha1, "password", "salt", 64)) == "a7bc9b6efea2cbd717da72d83bfcc4e17d0b6280"]]) - assert( toHex(Hi(hmac_sha1, "password", "salt", 4096)) == "4b007901b765489abead49d926f721d065a429c1", - [[FAIL: toHex(Hi(hmac_sha1, "password", "salt", 4096)) == "4b007901b765489abead49d926f721d065a429c1"]]) - -- assert( toHex(Hi(hmac_sha1, "password", "salt", 16777216)) == "eefe3d61cd4da4e4e9945b3d6ba2158c2634e984", - -- [[FAIL: toHex(Hi(hmac_sha1, "password", "salt", 16777216)) == "eefe3d61cd4da4e4e9945b3d6ba2158c2634e984"]]) -end - -function init(init) - -- no tests -end diff --git a/tests/test_util_stanza.lua b/tests/test_util_stanza.lua deleted file mode 100644 index 4be07a4b..00000000 --- a/tests/test_util_stanza.lua +++ /dev/null @@ -1,152 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - - -function preserialize(preserialize, st) - local stanza = st.stanza("message", { a = "a" }); - local stanza2 = preserialize(stanza); - assert_is(stanza2 and stanza.name, "preserialize returns a stanza"); - assert_is_not(stanza2.tags, "Preserialized stanza has no tag list"); - assert_is_not(stanza2.last_add, "Preserialized stanza has no last_add marker"); - assert_is_not(getmetatable(stanza2), "Preserialized stanza has no metatable"); -end - -function deserialize(deserialize, st) - local stanza = st.stanza("message", { a = "a" }); - - local stanza2 = deserialize(st.preserialize(stanza)); - assert_is(stanza2 and stanza.name, "deserialize returns a stanza"); - assert_table(stanza2.attr, "Deserialized stanza has attributes"); - assert_equal(stanza2.attr.a, "a", "Deserialized stanza retains attributes"); - assert_table(getmetatable(stanza2), "Deserialized stanza has metatable"); -end - -function stanza(stanza) - local s = stanza("foo", { xmlns = "myxmlns", a = "attr-a" }); - assert_equal(s.name, "foo"); - assert_equal(s.attr.xmlns, "myxmlns"); - assert_equal(s.attr.a, "attr-a"); - - local s1 = stanza("s1"); - assert_equal(s1.name, "s1"); - assert_equal(s1.attr.xmlns, nil); - assert_equal(#s1, 0); - assert_equal(#s1.tags, 0); - - s1:tag("child1"); - assert_equal(#s1.tags, 1); - assert_equal(s1.tags[1].name, "child1"); - - s1:tag("grandchild1"):up(); - assert_equal(#s1.tags, 1); - assert_equal(s1.tags[1].name, "child1"); - assert_equal(#s1.tags[1], 1); - assert_equal(s1.tags[1][1].name, "grandchild1"); - - s1:up():tag("child2"); - assert_equal(#s1.tags, 2, tostring(s1)); - assert_equal(s1.tags[1].name, "child1"); - assert_equal(s1.tags[2].name, "child2"); - assert_equal(#s1.tags[1], 1); - assert_equal(s1.tags[1][1].name, "grandchild1"); - - s1:up():text("Hello world"); - assert_equal(#s1.tags, 2); - assert_equal(#s1, 3); - assert_equal(s1.tags[1].name, "child1"); - assert_equal(s1.tags[2].name, "child2"); - assert_equal(#s1.tags[1], 1); - assert_equal(s1.tags[1][1].name, "grandchild1"); -end - -function message(message) - local m = message(); - assert_equal(m.name, "message"); -end - -function iq(iq) - local i = iq(); - assert_equal(i.name, "iq"); -end - -function presence(presence) - local p = presence(); - assert_equal(p.name, "presence"); -end - -function reply(reply, _M) - do - -- Test stanza - local s = _M.stanza("s", { to = "touser", from = "fromuser", id = "123" }) - :tag("child1"); - -- Make reply stanza - local r = reply(s); - assert_equal(r.name, s.name); - assert_equal(r.id, s.id); - assert_equal(r.attr.to, s.attr.from); - assert_equal(r.attr.from, s.attr.to); - assert_equal(#r.tags, 0, "A reply should not include children of the original stanza"); - end - - do - -- Test stanza - local s = _M.stanza("iq", { to = "touser", from = "fromuser", id = "123", type = "get" }) - :tag("child1"); - -- Make reply stanza - local r = reply(s); - assert_equal(r.name, s.name); - assert_equal(r.id, s.id); - assert_equal(r.attr.to, s.attr.from); - assert_equal(r.attr.from, s.attr.to); - assert_equal(r.attr.type, "result"); - assert_equal(#r.tags, 0, "A reply should not include children of the original stanza"); - end - - do - -- Test stanza - local s = _M.stanza("iq", { to = "touser", from = "fromuser", id = "123", type = "set" }) - :tag("child1"); - -- Make reply stanza - local r = reply(s); - assert_equal(r.name, s.name); - assert_equal(r.id, s.id); - assert_equal(r.attr.to, s.attr.from); - assert_equal(r.attr.from, s.attr.to); - assert_equal(r.attr.type, "result"); - assert_equal(#r.tags, 0, "A reply should not include children of the original stanza"); - end -end - -function error_reply(error_reply, _M) - do - -- Test stanza - local s = _M.stanza("s", { to = "touser", from = "fromuser", id = "123" }) - :tag("child1"); - -- Make reply stanza - local r = error_reply(s); - assert_equal(r.name, s.name); - assert_equal(r.id, s.id); - assert_equal(r.attr.to, s.attr.from); - assert_equal(r.attr.from, s.attr.to); - assert_equal(#r.tags, 1); - end - - do - -- Test stanza - local s = _M.stanza("iq", { to = "touser", from = "fromuser", id = "123", type = "get" }) - :tag("child1"); - -- Make reply stanza - local r = error_reply(s); - assert_equal(r.name, s.name); - assert_equal(r.id, s.id); - assert_equal(r.attr.to, s.attr.from); - assert_equal(r.attr.from, s.attr.to); - assert_equal(r.attr.type, "error"); - assert_equal(#r.tags, 1); - end -end diff --git a/tests/test_util_throttle.lua b/tests/test_util_throttle.lua deleted file mode 100644 index 6d47238a..00000000 --- a/tests/test_util_throttle.lua +++ /dev/null @@ -1,26 +0,0 @@ - -local now = 0; -- wibbly-wobbly... timey-wimey... stuff -local function predictable_gettime() - return now; -end -local function later(n) - now = now + n; -- time passes at a different rate -end - -package.loaded["util.time"] = { - now = predictable_gettime; -} - -function create(create) - local a = create(3, 10); - - assert_equal(a:poll(1), true); -- 3 -> 2 - assert_equal(a:poll(1), true); -- 2 -> 1 - assert_equal(a:poll(1), true); -- 1 -> 0 - assert_equal(a:poll(1), false); -- MEEP, out of credits! - later(1); -- ... what about - assert_equal(a:poll(1), false); -- now? - Still no! - later(9); -- Later that day - assert_equal(a:poll(1), true); -- Should be back at 3 credits ... 2 -end - diff --git a/tests/test_util_uuid.lua b/tests/test_util_uuid.lua deleted file mode 100644 index 07d75025..00000000 --- a/tests/test_util_uuid.lua +++ /dev/null @@ -1,24 +0,0 @@ --- This tests the format, not the randomness - --- https://tools.ietf.org/html/rfc4122#section-4.4 - -local pattern = "^" .. table.concat({ - string.rep("%x", 8), - string.rep("%x", 4), - "4" .. -- version - string.rep("%x", 3), - "[89ab]" .. -- reserved bits of 1 and 0 - string.rep("%x", 3), - string.rep("%x", 12), -}, "%-") .. "$"; - -function generate(generate) - for _ = 1, 100 do - assert_is(generate():match(pattern)); - end -end - -function seed(seed) - assert_equal(seed("random string here"), nil, "seed doesn't return anything"); -end - diff --git a/tests/test_util_xml.lua b/tests/test_util_xml.lua deleted file mode 100644 index ba44da19..00000000 --- a/tests/test_util_xml.lua +++ /dev/null @@ -1,12 +0,0 @@ -function parse(parse) - local x = -[[<x xmlns:a="b"> - <y xmlns:a="c"> <!-- this overwrites 'a' --> - <a:z/> - </y> - <a:z/> <!-- prefix 'a' is nil here, but should be 'b' --> -</x> -]] - local stanza = parse(x); - assert_equal(stanza.tags[2].attr.xmlns, "b"); -end diff --git a/tests/test_util_xmppstream.lua b/tests/test_util_xmppstream.lua deleted file mode 100644 index 791cf999..00000000 --- a/tests/test_util_xmppstream.lua +++ /dev/null @@ -1,83 +0,0 @@ -function new(new_stream, _M) - local function test(xml, expect_success, ex) - local stanzas = {}; - local session = { notopen = true }; - local callbacks = { - stream_ns = "streamns"; - stream_tag = "stream"; - default_ns = "stanzans"; - streamopened = function (_session) - assert_equal(session, _session); - assert_equal(session.notopen, true); - _session.notopen = nil; - return true; - end; - handlestanza = function (_session, stanza) - assert_equal(session, _session); - assert_equal(_session.notopen, nil); - table.insert(stanzas, stanza); - end; - streamclosed = function (_session) - assert_equal(session, _session); - assert_equal(_session.notopen, nil); - _session.notopen = nil; - end; - } - if type(ex) == "table" then - for k, v in pairs(ex) do - if k ~= "_size_limit" then - callbacks[k] = v; - end - end - end - local stream = new_stream(session, callbacks, size_limit); - local ok, err = pcall(function () - assert(stream:feed(xml)); - end); - - if ok and type(expect_success) == "function" then - expect_success(stanzas); - end - assert_equal(not not ok, not not expect_success, "Expected "..(expect_success and ("success ("..tostring(err)..")") or "failure")); - end - - local function test_stanza(stanza, expect_success, ex) - return test([[<stream:stream xmlns:stream="streamns" xmlns="stanzans">]]..stanza, expect_success, ex); - end - - test([[<stream:stream xmlns:stream="streamns"/>]], true); - test([[<stream xmlns="streamns"/>]], true); - - test([[<stream1 xmlns="streamns"/>]], false); - test([[<stream xmlns="streamns1"/>]], false); - test("<>", false); - - test_stanza("<message/>", function (stanzas) - assert_equal(#stanzas, 1); - assert_equal(stanzas[1].name, "message"); - end); - test_stanza("< message>>>>/>\n", false); - - test_stanza([[<x xmlns:a="b"> - <y xmlns:a="c"> - <a:z/> - </y> - <a:z/> - </x>]], function (stanzas) - assert_equal(#stanzas, 1); - local s = stanzas[1]; - assert_equal(s.name, "x"); - assert_equal(#s.tags, 2); - - assert_equal(s.tags[1].name, "y"); - assert_equal(s.tags[1].attr.xmlns, nil); - - assert_equal(s.tags[1].tags[1].name, "z"); - assert_equal(s.tags[1].tags[1].attr.xmlns, "c"); - - assert_equal(s.tags[2].name, "z"); - assert_equal(s.tags[2].attr.xmlns, "b"); - - assert_equal(s.namespaces, nil); - end); -end diff --git a/tests/util/logger.lua b/tests/util/logger.lua deleted file mode 100644 index 44860d5d..00000000 --- a/tests/util/logger.lua +++ /dev/null @@ -1,45 +0,0 @@ --- Prosody IM --- Copyright (C) 2008-2010 Matthew Wild --- Copyright (C) 2008-2010 Waqas Hussain --- --- This project is MIT/X11 licensed. Please see the --- COPYING file in the source package for more information. --- - -local format = string.format; -local print = print; -local debug = debug; -local tostring = tostring; - -local getstyle, getstring = require "util.termcolours".getstyle, require "util.termcolours".getstring; -local do_pretty_printing = not os.getenv("WINDIR"); - -local _ENV = nil -local _M = {} - -local logstyles = {}; - ---TODO: This should be done in config, but we don't have proper config yet -if do_pretty_printing then - logstyles["info"] = getstyle("bold"); - logstyles["warn"] = getstyle("bold", "yellow"); - logstyles["error"] = getstyle("bold", "red"); -end - -function _M.init(name) - --name = nil; -- While this line is not commented, will automatically fill in file/line number info - return function (level, message, ...) - if level == "debug" or level == "info" then return; end - if not name then - local inf = debug.getinfo(3, 'Snl'); - level = level .. ","..tostring(inf.short_src):match("[^/]*$")..":"..inf.currentline; - end - if ... then - print(name, getstring(logstyles[level], level), format(message, ...)); - else - print(name, getstring(logstyles[level], level), message); - end - end -end - -return _M; diff --git a/tools/migration/migrator/prosody_sql.lua b/tools/migration/migrator/prosody_sql.lua index 2e902fbb..6df2b025 100644 --- a/tools/migration/migrator/prosody_sql.lua +++ b/tools/migration/migrator/prosody_sql.lua @@ -12,7 +12,7 @@ local tostring = tostring; local tonumber = tonumber; if not have_DBI then - error("LuaDBI (required for SQL support) was not found, please see http://prosody.im/doc/depends#luadbi", 0); + error("LuaDBI (required for SQL support) was not found, please see https://prosody.im/doc/depends#luadbi", 0); end local sql = require "util.sql"; diff --git a/util-src/Makefile b/util-src/GNUmakefile index f18d5a80..4b8540c5 100644 --- a/util-src/Makefile +++ b/util-src/GNUmakefile @@ -18,7 +18,7 @@ endif all: $(ALL) install: $(ALL) - $(INSTALL_DATA) $^ $(TARGET) + $(INSTALL_DATA) $? $(TARGET) clean: rm -f $(ALL) $(patsubst %.so,%.o,$(ALL)) diff --git a/util-src/crand.c b/util-src/crand.c index 7eea1f2b..160ac1f6 100644 --- a/util-src/crand.c +++ b/util-src/crand.c @@ -21,19 +21,22 @@ #define _DEFAULT_SOURCE -#include "lualib.h" -#include "lauxlib.h" - +#include <stdlib.h> #include <string.h> #include <errno.h> +#include "lualib.h" +#include "lauxlib.h" + #if defined(WITH_GETRANDOM) #ifndef __GLIBC_PREREQ +/* Not compiled with glibc at all */ #define __GLIBC_PREREQ(a,b) 0 #endif #if ! __GLIBC_PREREQ(2,25) +/* Not compiled with a glibc that provides getrandom() */ #include <unistd.h> #include <sys/syscall.h> @@ -49,45 +52,66 @@ int getrandom(void *buf, size_t buflen, unsigned int flags) { #include <sys/random.h> #endif -#elif defined(WITH_ARC4RANDOM) -#include <stdlib.h> #elif defined(WITH_OPENSSL) #include <openssl/rand.h> +#elif defined(WITH_ARC4RANDOM) +#ifdef __linux__ +#include <bsd/stdlib.h> +#endif #else #error util.crand compiled without a random source #endif +#ifndef SMALLBUFSIZ +#define SMALLBUFSIZ 32 +#endif + int Lrandom(lua_State *L) { - int ret = 0; - size_t len = (size_t)luaL_checkinteger(L, 1); - void *buf = lua_newuserdata(L, len); + char smallbuf[SMALLBUFSIZ]; + char *buf = &smallbuf[0]; + const lua_Integer l = luaL_checkinteger(L, 1); + const size_t len = l; + luaL_argcheck(L, l >= 0, 1, "must be > 0"); + + if(len == 0) { + lua_pushliteral(L, ""); + return 1; + } + + if(len > SMALLBUFSIZ) { + buf = lua_newuserdata(L, len); + } #if defined(WITH_GETRANDOM) /* * This acts like a read from /dev/urandom with the exception that it * *does* block if the entropy pool is not yet initialized. */ - ret = getrandom(buf, len, 0); + int left = len; + char *p = buf; - if(ret < 0) { - lua_pushstring(L, strerror(errno)); - return lua_error(L); - } + do { + int ret = getrandom(p, left, 0); + + if(ret < 0) { + lua_pushstring(L, strerror(errno)); + return lua_error(L); + } + + p += ret; + left -= ret; + } while(left > 0); #elif defined(WITH_ARC4RANDOM) arc4random_buf(buf, len); - ret = len; #elif defined(WITH_OPENSSL) + if(!RAND_status()) { lua_pushliteral(L, "OpenSSL PRNG not seeded"); return lua_error(L); } - ret = RAND_bytes(buf, len); - - if(ret == 1) { - ret = len; - } else { + if(RAND_bytes((unsigned char *)buf, len) != 1) { /* TODO ERR_get_error() */ lua_pushstring(L, "RAND_bytes() failed"); return lua_error(L); @@ -95,7 +119,7 @@ int Lrandom(lua_State *L) { #endif - lua_pushlstring(L, buf, ret); + lua_pushlstring(L, buf, len); return 1; } diff --git a/util-src/makefile b/util-src/makefile new file mode 100644 index 00000000..43ce2213 --- /dev/null +++ b/util-src/makefile @@ -0,0 +1,44 @@ +include ../config.unix + +CFLAGS+=-I$(LUA_INCDIR) + +INSTALL_DATA=install -m644 +TARGET?=../util/ + +ALL=encodings.so hashes.so net.so pposix.so signal.so table.so ringbuffer.so + +.ifdef $(RANDOM) +ALL+=crand.so +.endif + +.PHONY: all install clean +.SUFFIXES: .c .o .so + +all: $(ALL) + +install: $(ALL) + $(INSTALL_DATA) $(ALL) $(TARGET) + +clean: + rm -f $(ALL) $(patsubst %.so,%.o,$(ALL)) + +encodings.so: encodings.o + $(LD) $(LDFLAGS) -o $@ $< $(LDLIBS) $(IDNA_LIBS) + +hashes.so: hashes.o + $(LD) $(LDFLAGS) -o $@ $< $(LDLIBS) $(OPENSSL_LIBS) + +crand.o: crand.c + $(CC) $(CFLAGS) -DWITH_$(RANDOM) -c -o $@ $< + +crand.so: crand.o + $(LD) $(LDFLAGS) -o $@ $< $(LDLIBS) $(RANDOM_LIBS) + +%.so: %.o + $(LD) $(LDFLAGS) -o $@ $< $(LDLIBS) + +.c.o: + $(CC) $(CFLAGS) -c -o $@ $< + +.o.so: + $(LD) $(LDFLAGS) -o $@ $< $(LDLIBS) diff --git a/util-src/net.c b/util-src/net.c index bb159d57..9ff01a71 100644 --- a/util-src/net.c +++ b/util-src/net.c @@ -125,12 +125,75 @@ static int lc_local_addresses(lua_State *L) { return 1; } +static int lc_pton(lua_State *L) { + char buf[16]; + const char *ipaddr = luaL_checkstring(L, 1); + int errno_ = 0; + int family = strchr(ipaddr, ':') ? AF_INET6 : AF_INET; + + switch(inet_pton(family, ipaddr, &buf)) { + case 1: + lua_pushlstring(L, buf, family == AF_INET6 ? 16 : 4); + return 1; + + case -1: + errno_ = errno; + lua_pushnil(L); + lua_pushstring(L, strerror(errno_)); + lua_pushinteger(L, errno_); + return 3; + + default: + case 0: + lua_pushnil(L); + lua_pushstring(L, strerror(EINVAL)); + lua_pushinteger(L, EINVAL); + return 3; + } + +} + +static int lc_ntop(lua_State *L) { + char buf[INET6_ADDRSTRLEN]; + int family; + int errno_; + size_t l; + const char *ipaddr = luaL_checklstring(L, 1, &l); + + if(l == 16) { + family = AF_INET6; + } + else if(l == 4) { + family = AF_INET; + } + else { + lua_pushnil(L); + lua_pushstring(L, strerror(EAFNOSUPPORT)); + lua_pushinteger(L, EAFNOSUPPORT); + return 3; + } + + if(!inet_ntop(family, ipaddr, buf, INET6_ADDRSTRLEN)) + { + errno_ = errno; + lua_pushnil(L); + lua_pushstring(L, strerror(errno_)); + lua_pushinteger(L, errno_); + return 3; + } + + lua_pushstring(L, (const char *)(&buf)); + return 1; +} + int luaopen_util_net(lua_State *L) { #if (LUA_VERSION_NUM > 501) luaL_checkversion(L); #endif luaL_Reg exports[] = { { "local_addresses", lc_local_addresses }, + { "pton", lc_pton }, + { "ntop", lc_ntop }, { NULL, NULL } }; diff --git a/util-src/ringbuffer.c b/util-src/ringbuffer.c index 8d9e49e6..8f9013f7 100644 --- a/util-src/ringbuffer.c +++ b/util-src/ringbuffer.c @@ -39,10 +39,12 @@ int find(ringbuffer *b, const char *s, size_t l) { return 0; } + /* look for a matching first byte */ for(i = 0; i <= b->blen - l; i++) { if(b->buffer[(b->rpos + i) % b->alen] == *s) { m = 1; + /* check if the following byte also match */ for(j = 1; j < l; j++) if(b->buffer[(b->rpos + i + j) % b->alen] != s[j]) { m = 0; @@ -58,6 +60,10 @@ int find(ringbuffer *b, const char *s, size_t l) { return 0; } +/* + * Find first position of a substring in buffer + * (buffer, string) -> number + */ int rb_find(lua_State *L) { size_t l, m; ringbuffer *b = luaL_checkudata(L, 1, "ringbuffer_mt"); @@ -72,6 +78,31 @@ int rb_find(lua_State *L) { return 0; } +/* + * Move read position forward without returning the data + * (buffer, number) -> boolean + */ +int rb_discard(lua_State *L) { + ringbuffer *b = luaL_checkudata(L, 1, "ringbuffer_mt"); + size_t r = luaL_checkinteger(L, 2); + + if(r > b->blen) { + lua_pushboolean(L, 0); + return 1; + } + + b->blen -= r; + b->rpos += r; + modpos(b); + + lua_pushboolean(L, 1); + return 1; +} + +/* + * Read bytes from buffer + * (buffer, number, boolean?) -> string + */ int rb_read(lua_State *L) { ringbuffer *b = luaL_checkudata(L, 1, "ringbuffer_mt"); size_t r = luaL_checkinteger(L, 2); @@ -83,6 +114,7 @@ int rb_read(lua_State *L) { } if((b->rpos + r) > b->alen) { + /* Substring wraps around to the beginning of the buffer */ lua_pushlstring(L, &b->buffer[b->rpos], b->alen - b->rpos); lua_pushlstring(L, b->buffer, r - (b->alen - b->rpos)); lua_concat(L, 2); @@ -99,6 +131,10 @@ int rb_read(lua_State *L) { return 1; } +/* + * Read buffer until first occurrence of a substring + * (buffer, string) -> string + */ int rb_readuntil(lua_State *L) { size_t l, m; ringbuffer *b = luaL_checkudata(L, 1, "ringbuffer_mt"); @@ -114,6 +150,10 @@ int rb_readuntil(lua_State *L) { return 0; } +/* + * Write bytes into the buffer + * (buffer, string) -> integer + */ int rb_write(lua_State *L) { size_t l, w = 0; ringbuffer *b = luaL_checkudata(L, 1, "ringbuffer_mt"); @@ -191,6 +231,8 @@ int luaopen_util_ringbuffer(lua_State *L) { { lua_pushcfunction(L, rb_find); lua_setfield(L, -2, "find"); + lua_pushcfunction(L, rb_discard); + lua_setfield(L, -2, "discard"); lua_pushcfunction(L, rb_read); lua_setfield(L, -2, "read"); lua_pushcfunction(L, rb_readuntil); diff --git a/util/adhoc.lua b/util/adhoc.lua index 17c9eee5..d81b8242 100644 --- a/util/adhoc.lua +++ b/util/adhoc.lua @@ -1,3 +1,5 @@ +-- luacheck: ignore 212/self + local function new_simple_form(form, result_handler) return function(self, data, state) if state then diff --git a/util/array.lua b/util/array.lua index 150b4355..1a8ffec7 100644 --- a/util/array.lua +++ b/util/array.lua @@ -19,7 +19,7 @@ local type = type; local array = {}; local array_base = {}; local array_methods = {}; -local array_mt = { __index = array_methods, __tostring = function (self) return "{"..self:concat(", ").."}"; end }; +local array_mt = { __index = array_methods, __name = "array", __tostring = function (self) return "{"..self:concat(", ").."}"; end }; local function new_array(self, t, _s, _var) if type(t) == "function" then -- Assume iterator diff --git a/util/async.lua b/util/async.lua new file mode 100644 index 00000000..0d19af6e --- /dev/null +++ b/util/async.lua @@ -0,0 +1,253 @@ +local logger = require "util.logger"; +local log = logger.init("util.async"); +local new_id = require "util.id".short; + +local function checkthread() + local thread, main = coroutine.running(); + if not thread or main then + error("Not running in an async context, see https://prosody.im/doc/developers/util/async"); + end + return thread; +end + +local function runner_from_thread(thread) + local level = 0; + -- Find the 'level' of the top-most function (0 == current level, 1 == caller, ...) + while debug.getinfo(thread, level, "") do level = level + 1; end + local name, runner = debug.getlocal(thread, level-1, 1); + if name ~= "self" or type(runner) ~= "table" or runner.thread ~= thread then + return nil; + end + return runner; +end + +local function call_watcher(runner, watcher_name, ...) + local watcher = runner.watchers[watcher_name]; + if not watcher then + return false; + end + runner:log("debug", "Calling '%s' watcher", watcher_name); + local ok, err = pcall(watcher, runner, ...); -- COMPAT: Switch to xpcall after Lua 5.1 + if not ok then + runner:log("error", "Error in '%s' watcher: %s", watcher_name, err); + return nil, err; + end + return true; +end + +local function runner_continue(thread) + -- ASSUMPTION: runner is in 'waiting' state (but we don't have the runner to know for sure) + if coroutine.status(thread) ~= "suspended" then -- This should suffice + log("error", "unexpected async state: thread not suspended"); + return false; + end + local ok, state, runner = coroutine.resume(thread); + if not ok then + local err = state; + -- Running the coroutine failed, which means we have to find the runner manually, + -- in order to inform the error handler + runner = runner_from_thread(thread); + if not runner then + log("error", "unexpected async state: unable to locate runner during error handling"); + return false; + end + call_watcher(runner, "error", debug.traceback(thread, err)); + runner.state, runner.thread = "ready", nil; + return runner:run(); + elseif state == "ready" then + -- If state is 'ready', it is our responsibility to update runner.state from 'waiting'. + -- We also have to :run(), because the queue might have further items that will not be + -- processed otherwise. FIXME: It's probably best to do this in a nexttick (0 timer). + runner.state = "ready"; + runner:run(); + end + return true; +end + +local function waiter(num) + local thread = checkthread(); + num = num or 1; + local waiting; + return function () + if num == 0 then return; end -- already done + waiting = true; + coroutine.yield("wait"); + end, function () + num = num - 1; + if num == 0 and waiting then + runner_continue(thread); + elseif num < 0 then + error("done() called too many times"); + end + end; +end + +local function guarder() + local guards = {}; + local default_id = {}; + return function (id, func) + id = id or default_id; + local thread = checkthread(); + local guard = guards[id]; + if not guard then + guard = {}; + guards[id] = guard; + log("debug", "New guard!"); + else + table.insert(guard, thread); + log("debug", "Guarded. %d threads waiting.", #guard) + coroutine.yield("wait"); + end + local function exit() + local next_waiting = table.remove(guard, 1); + if next_waiting then + log("debug", "guard: Executing next waiting thread (%d left)", #guard) + runner_continue(next_waiting); + else + log("debug", "Guard off duty.") + guards[id] = nil; + end + end + if func then + func(); + exit(); + return; + end + return exit; + end; +end + +local runner_mt = {}; +runner_mt.__index = runner_mt; + +local function runner_create_thread(func, self) + local thread = coroutine.create(function (self) -- luacheck: ignore 432/self + while true do + func(coroutine.yield("ready", self)); + end + end); + debug.sethook(thread, debug.gethook()); + assert(coroutine.resume(thread, self)); -- Start it up, it will return instantly to wait for the first input + return thread; +end + +local function default_error_watcher(runner, err) + runner:log("error", "Encountered error: %s", err); + error(err); +end +local function default_func(f) f(); end +local function runner(func, watchers, data) + local id = new_id(); + local _log = logger.init("runner" .. id); + return setmetatable({ func = func or default_func, thread = false, state = "ready", notified_state = "ready", + queue = {}, watchers = watchers or { error = default_error_watcher }, data = data, id = id, _log = _log; } + , runner_mt); +end + +-- Add a task item for the runner to process +function runner_mt:run(input) + if input ~= nil then + table.insert(self.queue, input); + self:log("debug", "queued new work item, %d items queued", #self.queue); + end + if self.state ~= "ready" then + -- The runner is busy. Indicate that the task item has been + -- queued, and return information about the current runner state + return true, self.state, #self.queue; + end + + local q, thread = self.queue, self.thread; + if not thread or coroutine.status(thread) == "dead" then + self:log("debug", "creating new coroutine"); + -- Create a new coroutine for this runner + thread = runner_create_thread(self.func, self); + self.thread = thread; + end + + -- Process task item(s) while the queue is not empty, and we're not blocked + local n, state, err = #q, self.state, nil; + self.state = "running"; + self:log("debug", "running main loop"); + while n > 0 and state == "ready" and not err do + local consumed; + -- Loop through queue items, and attempt to run them + for i = 1,n do + local queued_input = q[i]; + local ok, new_state = coroutine.resume(thread, queued_input); + if not ok then + -- There was an error running the coroutine, save the error, mark runner as ready to begin again + consumed, state, err = i, "ready", debug.traceback(thread, new_state); + self.thread = nil; + break; + elseif new_state == "wait" then + -- Runner is blocked on waiting for a task item to complete + consumed, state = i, "waiting"; + break; + end + end + -- Loop ended - either queue empty because all tasks passed without blocking (consumed == nil) + -- or runner is blocked/errored, and consumed will contain the number of tasks processed so far + if not consumed then consumed = n; end + -- Remove consumed items from the queue array + if q[n+1] ~= nil then + n = #q; + end + for i = 1, n do + q[i] = q[consumed+i]; + end + n = #q; + end + -- Runner processed all items it can, so save current runner state + self.state = state; + if err or state ~= self.notified_state then + self:log("debug", "changed state from %s to %s", self.notified_state, err and ("error ("..state..")") or state); + if err then + state = "error" + else + self.notified_state = state; + end + local handler = self.watchers[state]; + if handler then handler(self, err); end + end + if n > 0 then + return self:run(); + end + return true, state, n; +end + +-- Add a task item to the queue without invoking the runner, even if it is idle +function runner_mt:enqueue(input) + table.insert(self.queue, input); + self:log("debug", "queued new work item, %d items queued", #self.queue); + return self; +end + +function runner_mt:log(level, fmt, ...) + return self._log(level, fmt, ...); +end + +function runner_mt:onready(f) + self.watchers.ready = f; + return self; +end + +function runner_mt:onwaiting(f) + self.watchers.waiting = f; + return self; +end + +function runner_mt:onerror(f) + self.watchers.error = f; + return self; +end + +local function ready() + return pcall(checkthread); +end + +return { + ready = ready; + waiter = waiter; + guarder = guarder; + runner = runner; +}; diff --git a/util/cache.lua b/util/cache.lua index 9c141bb6..a5fd5e6d 100644 --- a/util/cache.lua +++ b/util/cache.lua @@ -116,6 +116,25 @@ function cache_methods:tail() return tail.key, tail.value; end +function cache_methods:resize(new_size) + new_size = assert(tonumber(new_size), "cache size must be a number"); + new_size = math.floor(new_size); + assert(new_size > 0, "cache size must be greater than zero"); + local on_evict = self._on_evict; + while self._count > new_size do + local tail = self._tail; + local evicted_key, evicted_value = tail.key, tail.value; + if on_evict ~= nil and (on_evict == false or on_evict(evicted_key, evicted_value) == false) then + -- Cache is full, and we're not allowed to evict + return false; + end + _remove(self, tail); + self._data[evicted_key] = nil; + end + self.size = new_size; + return true; +end + function cache_methods:table() --luacheck: ignore 212/t if not self.proxy_table then @@ -139,6 +158,13 @@ function cache_methods:table() return self.proxy_table; end +function cache_methods:clear() + self._data = {}; + self._count = 0; + self._head = nil; + self._tail = nil; +end + local function new(size, on_evict) size = assert(tonumber(size), "cache size must be a number"); size = math.floor(size); diff --git a/util/caps.lua b/util/caps.lua index cd5ff9c0..de492edb 100644 --- a/util/caps.lua +++ b/util/caps.lua @@ -13,6 +13,7 @@ local t_insert, t_sort, t_concat = table.insert, table.sort, table.concat; local ipairs = ipairs; local _ENV = nil; +-- luacheck: std none local function calculate_hash(disco_info) local identities, features, extensions = {}, {}, {}; diff --git a/util/dataforms.lua b/util/dataforms.lua index 469ce976..e48f6879 100644 --- a/util/dataforms.lua +++ b/util/dataforms.lua @@ -8,12 +8,13 @@ local setmetatable = setmetatable; local ipairs = ipairs; -local tostring, type, next = tostring, type, next; +local type, next = type, next; local t_concat = table.concat; local st = require "util.stanza"; local jid_prep = require "util.jid".prep; local _ENV = nil; +-- luacheck: std none local xmlns_forms = 'jabber:x:data'; @@ -37,6 +38,10 @@ function form_t.form(layout, data, formtype) -- Add field tag form:tag("field", { type = field_type, var = field.name, label = field.label }); + if field.desc then + form:text_tag("desc", field.desc); + end + local value = (data and data[field.name]) or field.value; if value then @@ -48,7 +53,7 @@ function form_t.form(layout, data, formtype) :add_child(value) :up(); else - form:tag("value"):text(tostring(value)):up(); + form:tag("value"):text(value):up(); end elseif field_type == "boolean" then form:tag("value"):text((value and "1") or "0"):up(); @@ -78,7 +83,7 @@ function form_t.form(layout, data, formtype) has_default = true; end else - form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); + form:tag("option", { label= val }):tag("value"):text(val):up():up(); end end end @@ -94,7 +99,7 @@ function form_t.form(layout, data, formtype) form:tag("value"):text(val.value):up(); end else - form:tag("option", { label= val }):tag("value"):text(tostring(val)):up():up(); + form:tag("option", { label= val }):tag("value"):text(val):up():up(); end end end @@ -248,8 +253,24 @@ field_readers["hidden"] = return field_tag:get_child_text("value"); end + +local function get_form_type(form) + if not st.is_stanza(form) then + return nil, "not a stanza object"; + elseif form.attr.xmlns ~= "jabber:x:data" or form.name ~= "x" then + return nil, "not a dataform element"; + end + for field in form:childtags("field") do + if field.attr.var == "FORM_TYPE" then + return field:get_child_text("value"); + end + end + return ""; +end + return { new = new; + get_type = get_form_type; }; diff --git a/util/datamanager.lua b/util/datamanager.lua index bd8fb7bb..cf96887b 100644 --- a/util/datamanager.lua +++ b/util/datamanager.lua @@ -40,9 +40,10 @@ pcall(function() end); local _ENV = nil; +-- luacheck: std none ---- utils ----- -local encode, decode; +local encode, decode, store_encode; do local urlcodes = setmetatable({}, { __index = function (t, k) t[k] = char(tonumber(k, 16)); return t[k]; end }); @@ -53,6 +54,12 @@ do encode = function (s) return s and (s:gsub("%W", function (c) return format("%%%02x", c:byte()); end)); end + + -- Special encode function for store names, which historically were unencoded. + -- All currently known stores use a-z and underscore, so this one preserves underscores. + store_encode = function (s) + return s and (s:gsub("[^_%w]", function (c) return format("%%%02x", c:byte()); end)); + end end if not atomic_append then @@ -119,6 +126,7 @@ local function getpath(username, host, datastore, ext, create) ext = ext or "dat"; host = (host and encode(host)) or "_global"; username = username and encode(username); + datastore = store_encode(datastore); if username then if create then mkdir(mkdir(mkdir(data_path).."/"..host).."/"..datastore); end return format("%s/%s/%s/%s.%s", data_path, host, datastore, username, ext); diff --git a/util/datetime.lua b/util/datetime.lua index abb4e867..06be9fc2 100644 --- a/util/datetime.lua +++ b/util/datetime.lua @@ -15,6 +15,7 @@ local os_difftime = os.difftime; local tonumber = tonumber; local _ENV = nil; +-- luacheck: std none local function date(t) return os_date("!%Y-%m-%d", t); diff --git a/util/debug.lua b/util/debug.lua index 00f476d0..9a28395a 100644 --- a/util/debug.lua +++ b/util/debug.lua @@ -47,6 +47,7 @@ local function get_upvalues_table(func) for upvalue_num = 1, math.huge do local name, value = debug.getupvalue(func, upvalue_num); if not name then break; end + if name == "" then name = ("[%d]"):format(upvalue_num); end table.insert(upvalues, { name = name, value = value }); end end @@ -112,7 +113,9 @@ end local function build_source_boundary_marker(last_source_desc) local padding = string.rep("-", math.floor(((optimal_line_length - 6) - #last_source_desc)/2)); - return getstring(styles.boundary_padding, "v"..padding).." "..getstring(styles.filename, last_source_desc).." "..getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); + return getstring(styles.boundary_padding, "v"..padding).." ".. + getstring(styles.filename, last_source_desc).." ".. + getstring(styles.boundary_padding, padding..(#last_source_desc%2==0 and "-v" or "v ")); end local function _traceback(thread, message, level) @@ -142,9 +145,9 @@ local function _traceback(thread, message, level) local last_source_desc; local lines = {}; - for nlevel, level in ipairs(levels) do - local info = level.info; - local line = "..."; + for nlevel, current_level in ipairs(levels) do + local info = current_level.info; + local line; local func_type = info.namewhat.." "; local source_desc = (info.short_src == "[C]" and "C code") or info.short_src or "Unknown"; if func_type == " " then func_type = ""; end; @@ -160,7 +163,9 @@ local function _traceback(thread, message, level) if func_type == "global " or func_type == "local " then func_type = func_type.."function "; end - line = "[Lua] "..getstring(styles.location, info.short_src.." line "..info.currentline).." in "..func_type..getstring(styles.funcname, name).." (defined on line "..info.linedefined..")"; + line = "[Lua] "..getstring(styles.location, info.short_src.." line ".. + info.currentline).." in "..func_type..getstring(styles.funcname, name).. + " (defined on line "..info.linedefined..")"; end if source_desc ~= last_source_desc then -- Venturing into a new source, add marker for previous last_source_desc = source_desc; @@ -169,13 +174,13 @@ local function _traceback(thread, message, level) nlevel = nlevel-1; table.insert(lines, "\t"..(nlevel==0 and ">" or " ")..getstring(styles.level_num, "("..nlevel..") ")..line); local npadding = (" "):rep(#tostring(nlevel)); - if level.locals then - local locals_str = string_from_var_table(level.locals, optimal_line_length, "\t "..npadding); + if current_level.locals then + local locals_str = string_from_var_table(current_level.locals, optimal_line_length, "\t "..npadding); if locals_str then table.insert(lines, "\t "..npadding.."Locals: "..locals_str); end end - local upvalues_str = string_from_var_table(level.upvalues, optimal_line_length, "\t "..npadding); + local upvalues_str = string_from_var_table(current_level.upvalues, optimal_line_length, "\t "..npadding); if upvalues_str then table.insert(lines, "\t "..npadding.."Upvals: "..upvalues_str); end diff --git a/util/dependencies.lua b/util/dependencies.lua index de840241..9b0afd77 100644 --- a/util/dependencies.lua +++ b/util/dependencies.lua @@ -28,7 +28,7 @@ local function missingdep(name, sources, msg) end print(""); print(msg or (name.." is required for Prosody to run, so we will now exit.")); - print("More help can be found on our website, at http://prosody.im/doc/depends"); + print("More help can be found on our website, at https://prosody.im/doc/depends"); print("**************************"); print(""); end @@ -40,7 +40,7 @@ end package.preload["util.ztact"] = function () if not package.loaded["core.loggingmanager"] then error("util.ztact has been removed from Prosody and you need to fix your config " - .."file. More information can be found at http://prosody.im/doc/packagers#ztact", 0); + .."file. More information can be found at https://prosody.im/doc/packagers#ztact", 0); else error("module 'util.ztact' has been deprecated in Prosody 0.8."); end @@ -156,7 +156,7 @@ local function log_warnings() if ssl then local major, minor, veryminor, patched = ssl._VERSION:match("(%d+)%.(%d+)%.?(%d*)(M?)"); if not major or ((tonumber(major) == 0 and (tonumber(minor) or 0) <= 3 and (tonumber(veryminor) or 0) <= 2) and patched ~= "M") then - prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see http://prosody.im/doc/depends"); + prosody.log("error", "This version of LuaSec contains a known bug that causes disconnects, see https://prosody.im/doc/depends"); end end local lxp = softreq"lxp"; @@ -165,7 +165,7 @@ local function log_warnings() prosody.log("error", "The version of LuaExpat on your system leaves Prosody " .."vulnerable to denial-of-service attacks. You should upgrade to " .."LuaExpat 1.3.0 or higher as soon as possible. See " - .."http://prosody.im/doc/depends#luaexpat for more information."); + .."https://prosody.im/doc/depends#luaexpat for more information."); end if not lxp.new({}).getcurrentbytecount then prosody.log("error", "The version of LuaExpat on your system does not support " @@ -173,7 +173,7 @@ local function log_warnings() .."networks (e.g. the internet) vulnerable to denial-of-service " .."attacks. You should upgrade to LuaExpat 1.3.0 or higher as " .."soon as possible. See " - .."http://prosody.im/doc/depends#luaexpat for more information."); + .."https://prosody.im/doc/depends#luaexpat for more information."); end end end diff --git a/util/envload.lua b/util/envload.lua index 926f20c0..6182a1f9 100644 --- a/util/envload.lua +++ b/util/envload.lua @@ -4,7 +4,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- --- luacheck: ignore 113/setfenv +-- luacheck: ignore 113/setfenv 113/loadstring local load, loadstring, setfenv = load, loadstring, setfenv; local io_open = io.open; diff --git a/util/events.lua b/util/events.lua index 6e13619c..0bf0ddcb 100644 --- a/util/events.lua +++ b/util/events.lua @@ -15,6 +15,7 @@ local setmetatable = setmetatable; local next = next; local _ENV = nil; +-- luacheck: std none local function new() -- Map event name to ordered list of handlers (lazily built): handlers[event_name] = array_of_handler_functions @@ -26,7 +27,7 @@ local function new() -- Event map: event_map[handler_function] = priority_number local event_map = {}; -- Called on-demand to build handlers entries - local function _rebuild_index(handlers, event) + local function _rebuild_index(self, event) local _handlers = event_map[event]; if not _handlers or next(_handlers) == nil then return; end local index = {}; @@ -34,7 +35,7 @@ local function new() t_insert(index, handler); end t_sort(index, function(a, b) return _handlers[a] > _handlers[b]; end); - handlers[event] = index; + self[event] = index; return index; end; setmetatable(handlers, { __index = _rebuild_index }); @@ -61,13 +62,13 @@ local function new() local function get_handlers(event) return handlers[event]; end; - local function add_handlers(handlers) - for event, handler in pairs(handlers) do + local function add_handlers(self) + for event, handler in pairs(self) do add_handler(event, handler); end end; - local function remove_handlers(handlers) - for event, handler in pairs(handlers) do + local function remove_handlers(self) + for event, handler in pairs(self) do remove_handler(event, handler); end end; @@ -81,6 +82,7 @@ local function new() end end; local function fire_event(event_name, event_data) + -- luacheck: ignore 432/event_name 432/event_data local w = wrappers[event_name] or global_wrappers; if w then local curr_wrapper = #w; diff --git a/util/filters.lua b/util/filters.lua index f405c0bd..f30dfd9c 100644 --- a/util/filters.lua +++ b/util/filters.lua @@ -9,6 +9,7 @@ local t_insert, t_remove = table.insert, table.remove; local _ENV = nil; +-- luacheck: std none local new_filter_hooks = {}; diff --git a/util/format.lua b/util/format.lua index 5f2b12be..c5e513fa 100644 --- a/util/format.lua +++ b/util/format.lua @@ -4,11 +4,10 @@ local tostring = tostring; local select = select; -local assert = assert; -local unpack = unpack; +local unpack = table.unpack or unpack; -- luacheck: ignore 113/unpack local type = type; -local function format(format, ...) +local function format(formatstring, ...) local args, args_length = { ... }, select('#', ...); -- format specifier spec: @@ -25,7 +24,7 @@ local function format(format, ...) -- process each format specifier local i = 0; - format = format:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) + formatstring = formatstring:gsub("%%[^cdiouxXaAeEfgGqs%%]*[cdiouxXaAeEfgGqs%%]", function(spec) if spec ~= "%%" then i = i + 1; local arg = args[i]; @@ -54,21 +53,12 @@ local function format(format, ...) else args[i] = tostring(arg); end - format = format .. " [%s]" + formatstring = formatstring .. " [%s]" end - return format:format(unpack(args)); -end - -local function test() - assert(format("%s", "hello") == "hello"); - assert(format("%s") == "<nil>"); - assert(format("%s", true) == "true"); - assert(format("%d", true) == "[true]"); - assert(format("%%", true) == "% [true]"); + return formatstring:format(unpack(args)); end return { format = format; - test = test; }; diff --git a/util/import.lua b/util/import.lua index c2b9dce1..8ecfe43c 100644 --- a/util/import.lua +++ b/util/import.lua @@ -8,9 +8,9 @@ -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 local t_insert = table.insert; -function import(module, ...) +function _G.import(module, ...) local m = package.loaded[module] or require(module); if type(m) == "table" and ... then local ret = {}; diff --git a/util/indexedbheap.lua b/util/indexedbheap.lua new file mode 100644 index 00000000..7f193d54 --- /dev/null +++ b/util/indexedbheap.lua @@ -0,0 +1,157 @@ + +local setmetatable = setmetatable; +local math_floor = math.floor; +local t_remove = table.remove; + +local function _heap_insert(self, item, sync, item2, index) + local pos = #self + 1; + while true do + local half_pos = math_floor(pos / 2); + if half_pos == 0 or item > self[half_pos] then break; end + self[pos] = self[half_pos]; + sync[pos] = sync[half_pos]; + index[sync[pos]] = pos; + pos = half_pos; + end + self[pos] = item; + sync[pos] = item2; + index[item2] = pos; +end + +local function _percolate_up(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + while k ~= 1 do + local parent = math_floor(k/2); + if tmp < self[parent] then break; end + self[k] = self[parent]; + sync[k] = sync[parent]; + index[sync[k]] = k; + k = parent; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _percolate_down(self, k, sync, index) + local tmp = self[k]; + local tmp_sync = sync[k]; + local size = #self; + local child = 2*k; + while 2*k <= size do + if child ~= size and self[child] > self[child + 1] then + child = child + 1; + end + if tmp > self[child] then + self[k] = self[child]; + sync[k] = sync[child]; + index[sync[k]] = k; + else + break; + end + + k = child; + child = 2*k; + end + self[k] = tmp; + sync[k] = tmp_sync; + index[tmp_sync] = k; + return k; +end + +local function _heap_pop(self, sync, index) + local size = #self; + if size == 0 then return nil; end + + local result = self[1]; + local result_sync = sync[1]; + index[result_sync] = nil; + if size == 1 then + self[1] = nil; + sync[1] = nil; + return result, result_sync; + end + self[1] = t_remove(self); + sync[1] = t_remove(sync); + index[sync[1]] = 1; + + _percolate_down(self, 1, sync, index); + + return result, result_sync; +end + +local indexed_heap = {}; + +function indexed_heap:insert(item, priority, id) + if id == nil then + id = self.current_id; + self.current_id = id + 1; + end + self.items[id] = item; + _heap_insert(self.priorities, priority, self.ids, id, self.index); + return id; +end +function indexed_heap:pop() + local priority, id = _heap_pop(self.priorities, self.ids, self.index); + if id then + local item = self.items[id]; + self.items[id] = nil; + return priority, item, id; + end +end +function indexed_heap:peek() + return self.priorities[1]; +end +function indexed_heap:reprioritize(id, priority) + local k = self.index[id]; + if k == nil then return; end + self.priorities[k] = priority; + + k = _percolate_up(self.priorities, k, self.ids, self.index); + _percolate_down(self.priorities, k, self.ids, self.index); +end +function indexed_heap:remove_index(k) + local result = self.priorities[k]; + if result == nil then return; end + + local result_sync = self.ids[k]; + local item = self.items[result_sync]; + local size = #self.priorities; + + self.priorities[k] = self.priorities[size]; + self.ids[k] = self.ids[size]; + self.index[self.ids[k]] = k; + + t_remove(self.priorities); + t_remove(self.ids); + + self.index[result_sync] = nil; + self.items[result_sync] = nil; + + if size > k then + k = _percolate_up(self.priorities, k, self.ids, self.index); + _percolate_down(self.priorities, k, self.ids, self.index); + end + + return result, item, result_sync; +end +function indexed_heap:remove(id) + return self:remove_index(self.index[id]); +end + +local mt = { __index = indexed_heap }; + +local _M = { + create = function() + return setmetatable({ + ids = {}; -- heap of ids, sync'd with priorities + items = {}; -- map id->items + priorities = {}; -- heap of priorities + index = {}; -- map of id->index of id in ids + current_id = 1.5 + }, mt); + end +}; +return _M; diff --git a/util/ip.lua b/util/ip.lua index 81a98ef7..0ec9e297 100644 --- a/util/ip.lua +++ b/util/ip.lua @@ -5,69 +5,76 @@ -- COPYING file in the source package for more information. -- +local net = require "util.net"; +local hex = require "util.hex"; + local ip_methods = {}; -local ip_mt = { __index = function (ip, key) return (ip_methods[key])(ip); end, - __tostring = function (ip) return ip.addr; end, - __eq = function (ipA, ipB) return ipA.addr == ipB.addr; end}; -local hex2bits = { ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111" }; + +local ip_mt = { + __index = function (ip, key) + local method = ip_methods[key]; + if not method then return nil; end + local ret = method(ip); + ip[key] = ret; + return ret; + end, + __tostring = function (ip) return ip.addr; end, + __eq = function (ipA, ipB) return ipA.packed == ipB.packed; end +}; + +local hex2bits = { + ["0"] = "0000", ["1"] = "0001", ["2"] = "0010", ["3"] = "0011", + ["4"] = "0100", ["5"] = "0101", ["6"] = "0110", ["7"] = "0111", + ["8"] = "1000", ["9"] = "1001", ["A"] = "1010", ["B"] = "1011", + ["C"] = "1100", ["D"] = "1101", ["E"] = "1110", ["F"] = "1111", +}; local function new_ip(ipStr, proto) - if not proto then - local sep = ipStr:match("^%x+(.)"); - if sep == ":" or (not(sep) and ipStr:sub(1,1) == ":") then - proto = "IPv6" - elseif sep == "." then - proto = "IPv4" - end - if not proto then - return nil, "invalid address"; - end - elseif proto ~= "IPv4" and proto ~= "IPv6" then - return nil, "invalid protocol"; - end local zone; - if proto == "IPv6" and ipStr:find('%', 1, true) then + if (not proto or proto == "IPv6") and ipStr:find('%', 1, true) then ipStr, zone = ipStr:match("^(.-)%%(.*)"); end - if proto == "IPv6" and ipStr:find('.', 1, true) then - local changed; - ipStr, changed = ipStr:gsub(":(%d+)%.(%d+)%.(%d+)%.(%d+)$", function(a,b,c,d) - return (":%04X:%04X"):format(a*256+b,c*256+d); - end); - if changed ~= 1 then return nil, "invalid-address"; end + + local packed, err = net.pton(ipStr); + if not packed then return packed, err end + if proto == "IPv6" and #packed ~= 16 then + return nil, "invalid-ipv6"; + elseif proto == "IPv4" and #packed ~= 4 then + return nil, "invalid-ipv4"; + elseif not proto then + if #packed == 16 then + proto = "IPv6"; + elseif #packed == 4 then + proto = "IPv4"; + else + return nil, "unknown protocol"; + end + elseif proto ~= "IPv6" and proto ~= "IPv4" then + return nil, "invalid protocol"; end - return setmetatable({ addr = ipStr, proto = proto, zone = zone }, ip_mt); + return setmetatable({ addr = ipStr, packed = packed, proto = proto, zone = zone }, ip_mt); +end + +function ip_methods:normal() + return net.ntop(self.packed); end -local function toBits(ip) - local result = ""; - local fields = {}; +function ip_methods.bits(ip) + return hex.to(ip.packed):upper():gsub(".", hex2bits); +end + +function ip_methods.bits_full(ip) if ip.proto == "IPv4" then ip = ip.toV4mapped; end - ip = (ip.addr):upper(); - ip:gsub("([^:]*):?", function (c) fields[#fields + 1] = c end); - if not ip:match(":$") then fields[#fields] = nil; end - for i, field in ipairs(fields) do - if field:len() == 0 and i ~= 1 and i ~= #fields then - for _ = 1, 16 * (9 - #fields) do - result = result .. "0"; - end - else - for _ = 1, 4 - field:len() do - result = result .. "0000"; - end - for j = 1, field:len() do - result = result .. hex2bits[field:sub(j, j)]; - end - end - end - return result; + return ip.bits; end +local match; + local function commonPrefixLength(ipA, ipB) - ipA, ipB = toBits(ipA), toBits(ipB); + ipA, ipB = ipA.bits_full, ipB.bits_full; for i = 1, 128 do if ipA:sub(i,i) ~= ipB:sub(i,i) then return i-1; @@ -76,56 +83,60 @@ local function commonPrefixLength(ipA, ipB) return 128; end +-- Instantiate once +local loopback = new_ip("::1"); +local loopback4 = new_ip("127.0.0.0"); +local sixtofour = new_ip("2002::"); +local teredo = new_ip("2001::"); +local linklocal = new_ip("fe80::"); +local linklocal4 = new_ip("169.254.0.0"); +local uniquelocal = new_ip("fc00::"); +local sitelocal = new_ip("fec0::"); +local sixbone = new_ip("3ffe::"); +local defaultunicast = new_ip("::"); +local multicast = new_ip("ff00::"); +local ipv6mapped = new_ip("::ffff:0:0"); + local function v4scope(ip) - local fields = {}; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - -- Loopback: - if fields[1] == 127 then + if match(ip, loopback4, 8) then return 0x2; - -- Link-local unicast: - elseif fields[1] == 169 and fields[2] == 254 then + elseif match(ip, linklocal4) then return 0x2; - -- Global unicast: - else + else -- Global unicast return 0xE; end end local function v6scope(ip) - -- Loopback: - if ip:match("^[0:]*1$") then + if ip == loopback then return 0x2; - -- Link-local unicast: - elseif ip:match("^[Ff][Ee][89ABab]") then + elseif match(ip, linklocal, 10) then return 0x2; - -- Site-local unicast: - elseif ip:match("^[Ff][Ee][CcDdEeFf]") then + elseif match(ip, sitelocal, 10) then return 0x5; - -- Multicast: - elseif ip:match("^[Ff][Ff]") then - return tonumber("0x"..ip:sub(4,4)); - -- Global unicast: - else + elseif match(ip, multicast, 10) then + return ip.packed:byte(2) % 0x10; + else -- Global unicast return 0xE; end end local function label(ip) - if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + if ip == loopback then return 0; - elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + elseif match(ip, sixtofour, 16) then return 2; - elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + elseif match(ip, teredo, 32) then return 5; - elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + elseif match(ip, uniquelocal, 7) then return 13; - elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + elseif match(ip, sitelocal, 10) then return 11; - elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + elseif match(ip, sixbone, 16) then return 12; - elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + elseif match(ip, defaultunicast, 96) then return 3; - elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + elseif match(ip, ipv6mapped, 96) then return 4; else return 1; @@ -133,91 +144,67 @@ local function label(ip) end local function precedence(ip) - if commonPrefixLength(ip, new_ip("::1", "IPv6")) == 128 then + if ip == loopback then return 50; - elseif commonPrefixLength(ip, new_ip("2002::", "IPv6")) >= 16 then + elseif match(ip, sixtofour, 16) then return 30; - elseif commonPrefixLength(ip, new_ip("2001::", "IPv6")) >= 32 then + elseif match(ip, teredo, 32) then return 5; - elseif commonPrefixLength(ip, new_ip("fc00::", "IPv6")) >= 7 then + elseif match(ip, uniquelocal, 7) then return 3; - elseif commonPrefixLength(ip, new_ip("fec0::", "IPv6")) >= 10 then + elseif match(ip, sitelocal, 10) then return 1; - elseif commonPrefixLength(ip, new_ip("3ffe::", "IPv6")) >= 16 then + elseif match(ip, sixbone, 16) then return 1; - elseif commonPrefixLength(ip, new_ip("::", "IPv6")) >= 96 then + elseif match(ip, defaultunicast, 96) then return 1; - elseif commonPrefixLength(ip, new_ip("::ffff:0:0", "IPv6")) >= 96 then + elseif match(ip, ipv6mapped, 96) then return 35; else return 40; end end -local function toV4mapped(ip) - local fields = {}; - local ret = "::ffff:"; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - ret = ret .. ("%02x"):format(fields[1]); - ret = ret .. ("%02x"):format(fields[2]); - ret = ret .. ":" - ret = ret .. ("%02x"):format(fields[3]); - ret = ret .. ("%02x"):format(fields[4]); - return new_ip(ret, "IPv6"); -end - function ip_methods:toV4mapped() if self.proto ~= "IPv4" then return nil, "No IPv4 address" end - local value = toV4mapped(self.addr); - self.toV4mapped = value; + local value = new_ip("::ffff:" .. self.normal); return value; end function ip_methods:label() - local value; if self.proto == "IPv4" then - value = label(self.toV4mapped); + return label(self.toV4mapped); else - value = label(self); + return label(self); end - self.label = value; - return value; end function ip_methods:precedence() - local value; if self.proto == "IPv4" then - value = precedence(self.toV4mapped); + return precedence(self.toV4mapped); else - value = precedence(self); + return precedence(self); end - self.precedence = value; - return value; end function ip_methods:scope() - local value; if self.proto == "IPv4" then - value = v4scope(self.addr); + return v4scope(self); else - value = v6scope(self.addr); + return v6scope(self); end - self.scope = value; - return value; end +local rfc1918_8 = new_ip("10.0.0.0"); +local rfc1918_12 = new_ip("172.16.0.0"); +local rfc1918_16 = new_ip("192.168.0.0"); +local rfc6598 = new_ip("100.64.0.0"); + function ip_methods:private() local private = self.scope ~= 0xE; if not private and self.proto == "IPv4" then - local ip = self.addr; - local fields = {}; - ip:gsub("([^.]*).?", function (c) fields[#fields + 1] = tonumber(c) end); - if fields[1] == 127 or fields[1] == 10 or (fields[1] == 192 and fields[2] == 168) - or (fields[1] == 172 and (fields[2] >= 16 or fields[2] <= 32)) then - private = true; - end + return match(self, rfc1918_8, 8) or match(self, rfc1918_12, 12) or match(self, rfc1918_16) or match(self, rfc6598, 10); end - self.private = private; return private; end @@ -231,15 +218,26 @@ local function parse_cidr(cidr) return new_ip(cidr), bits; end -local function match(ipA, ipB, bits) - local common_bits = commonPrefixLength(ipA, ipB); - if bits and ipB.proto == "IPv4" then - common_bits = common_bits - 96; -- v6 mapped addresses always share these bits +function match(ipA, ipB, bits) + if not bits or bits >= 128 or ipB.proto == "IPv4" and bits >= 32 then + return ipA == ipB; + elseif bits < 1 then + return true; + end + if ipA.proto ~= ipB.proto then + if ipA.proto == "IPv4" then + ipA = ipA.toV4mapped; + elseif ipB.proto == "IPv4" then + ipB = ipB.toV4mapped; + bits = bits + (128 - 32); + end end - return common_bits >= (bits or 128); + return ipA.bits:sub(1, bits) == ipB.bits:sub(1, bits); end -return {new_ip = new_ip, +return { + new_ip = new_ip, commonPrefixLength = commonPrefixLength, parse_cidr = parse_cidr, - match=match}; + match = match, +}; diff --git a/util/iterators.lua b/util/iterators.lua index bd150ff2..5d16d8c1 100644 --- a/util/iterators.lua +++ b/util/iterators.lua @@ -12,8 +12,13 @@ local it = {}; local t_insert = table.insert; local select, next = select, next; -local unpack = table.unpack or unpack; --luacheck: ignore 113 -local pack = table.pack or function (...) return { n = select("#", ...), ... }; end +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 +local pack = table.pack or function (...) return { n = select("#", ...), ... }; end -- luacheck: ignore 143 +local type = type; +local table, setmetatable = table, setmetatable; + +local _ENV = nil; +--luacheck: std none -- Reverse an iterator function it.reverse(f, s, var) @@ -184,4 +189,45 @@ function it.to_table(f, s, var) return t; end +local function _join_iter(j_s, j_var) + local iterators, current_idx = j_s[1], j_s[2]; + local f, s, var = unpack(iterators[current_idx], 1, 3); + if j_var ~= nil then + var = j_var; + end + local ret = pack(f(s, var)); + local var1 = ret[1]; + if var1 == nil then + -- End of this iterator, advance to next + if current_idx == #iterators then + -- No more iterators, return nil + return; + end + j_s[2] = current_idx + 1; + return _join_iter(j_s); + end + return unpack(ret, 1, ret.n); +end +local join_methods = {}; +local join_mt = { + __index = join_methods; + __call = function (t, s, var) --luacheck: ignore 212/t + return _join_iter(s, var); + end; +}; + +function join_methods:append(f, s, var) + table.insert(self, { f, s, var }); + return self, { self, 1 }; +end + +function join_methods:prepend(f, s, var) + table.insert(self, { f, s, var }, 1); + return self, { self, 1 }; +end + +function it.join(f, s, var) + return setmetatable({ {f, s, var} }, join_mt); +end + return it; diff --git a/util/jid.lua b/util/jid.lua index f402b7f4..37c48193 100644 --- a/util/jid.lua +++ b/util/jid.lua @@ -25,6 +25,7 @@ local unescapes = {}; for k,v in pairs(escapes) do unescapes[v] = k; end local _ENV = nil; +-- luacheck: std none local function split(jid) if not jid then return; end diff --git a/util/json.lua b/util/json.lua index cba54e8e..05af709a 100644 --- a/util/json.lua +++ b/util/json.lua @@ -27,9 +27,6 @@ module.null = null; local escapes = { ["\""] = "\\\"", ["\\"] = "\\\\", ["\b"] = "\\b", ["\f"] = "\\f", ["\n"] = "\\n", ["\r"] = "\\r", ["\t"] = "\\t"}; -local unescapes = { - ["\""] = "\"", ["\\"] = "\\", ["/"] = "/", - b = "\b", f = "\f", n = "\n", r = "\r", t = "\t"}; for i=0,31 do local ch = s_char(i); if not escapes[ch] then escapes[ch] = ("\\u%.4X"):format(i); end @@ -263,8 +260,9 @@ end local function _unescape_func(x) x = x:match("%x%x%x%x", 3); if x then - --if x >= 0xD800 and x <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair - return codepoint_to_utf8(tonumber(x, 16)); + local codepoint = tonumber(x, 16) + if codepoint >= 0xD800 and codepoint <= 0xDFFF then _unescape_error = true; end -- bad surrogate pair + return codepoint_to_utf8(codepoint); end _unescape_error = true; end @@ -276,7 +274,7 @@ function _readstring(json, index) --if s:find("[%z-\31]") then return nil, "control char in string"; end -- FIXME handle control characters _unescape_error = nil; - --s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); + s = s:gsub("\\u[dD][89abAB]%x%x\\u[dD][cdefCDEF]%x%x", _unescape_surrogate_func); -- FIXME handle escapes beyond BMP s = s:gsub("\\u.?.?.?.?", _unescape_func); if _unescape_error then return nil, "invalid escape"; end diff --git a/util/logger.lua b/util/logger.lua index e72b29bc..20a5cef2 100644 --- a/util/logger.lua +++ b/util/logger.lua @@ -8,8 +8,11 @@ -- luacheck: ignore 213/level local pairs = pairs; +local ipairs = ipairs; +local require = require; local _ENV = nil; +-- luacheck: std none local level_sinks = {}; @@ -67,10 +70,21 @@ local function add_level_sink(level, sink_function) end end +local function add_simple_sink(simple_sink_function, levels) + local format = require "util.format".format; + local function sink_function(name, level, msg, ...) + return simple_sink_function(name, level, format(msg, ...)); + end + for _, level in ipairs(levels or {"debug", "info", "warn", "error"}) do + add_level_sink(level, sink_function); + end +end + return { init = init; make_logger = make_logger; reset = reset; add_level_sink = add_level_sink; + add_simple_sink = add_simple_sink; new = make_logger; }; diff --git a/util/multitable.lua b/util/multitable.lua index e4321d3d..8d32ed8a 100644 --- a/util/multitable.lua +++ b/util/multitable.lua @@ -9,9 +9,10 @@ local select = select; local t_insert = table.insert; local pairs, next, type = pairs, next, type; -local unpack = table.unpack or unpack; --luacheck: ignore 113 +local unpack = table.unpack or unpack; --luacheck: ignore 113 143 local _ENV = nil; +-- luacheck: std none local function get(self, ...) local t = self.data; @@ -132,7 +133,7 @@ local function iter(self, ...) local maxdepth = select("#", ...); local stack = { self.data }; local keys = { }; - local function it(self) + local function it(self) -- luacheck: ignore 432/self local depth = #stack; local key = next(stack[depth], keys[depth]); if key == nil then -- Go up the stack diff --git a/util/openssl.lua b/util/openssl.lua index 703c6d15..32b5aea7 100644 --- a/util/openssl.lua +++ b/util/openssl.lua @@ -114,7 +114,7 @@ function ssl_config:add_xmppAddr(host) s_format("%s;%s", oid_xmppaddr, utf8string(host))); end -function ssl_config:from_prosody(hosts, config, certhosts) +function ssl_config:from_prosody(hosts, config, certhosts) -- luacheck: ignore 431/config -- TODO Decide if this should go elsewhere local found_matching_hosts = false; for i = 1, #certhosts do diff --git a/util/pluginloader.lua b/util/pluginloader.lua index 004855f0..9ab8f245 100644 --- a/util/pluginloader.lua +++ b/util/pluginloader.lua @@ -5,6 +5,7 @@ -- This project is MIT/X11 licensed. Please see the -- COPYING file in the source package for more information. -- +-- luacheck: ignore 113/CFG_PLUGINDIR local dir_sep, path_sep = package.config:match("^(%S+)%s(%S+)"); local plugin_dir = {}; diff --git a/util/presence.lua b/util/presence.lua index f6370354..8d1ae2d9 100644 --- a/util/presence.lua +++ b/util/presence.lua @@ -13,7 +13,6 @@ local function select_top_resources(user) local recipients = {}; for _, session in pairs(user.sessions) do -- find resource with greatest priority if session.presence then - -- TODO check active privacy list for session local p = session.priority; if p > priority then priority = p; diff --git a/util/prosodyctl.lua b/util/prosodyctl.lua index 8ae051ae..5f0c4d12 100644 --- a/util/prosodyctl.lua +++ b/util/prosodyctl.lua @@ -24,8 +24,6 @@ local io, os = io, os; local print = print; local tonumber = tonumber; -local CFG_SOURCEDIR = _G.CFG_SOURCEDIR; - local _G = _G; local prosody = prosody; @@ -66,7 +64,10 @@ local function getline() end local function getpass() - local stty_ret = os.execute("stty -echo 2>/dev/null"); + local stty_ret, _, status_code = os.execute("stty -echo 2>/dev/null"); + if status_code then -- COMPAT w/ Lua 5.1 + stty_ret = status_code; + end if stty_ret ~= 0 then io.write("\027[08m"); -- ANSI 'hidden' text attribute end @@ -189,8 +190,8 @@ local function getpid() pidfile = config.resolve_relative_path(prosody.paths.data, pidfile); - local modules_enabled = set.new(config.get("*", "modules_disabled")); - if prosody.platform ~= "posix" or modules_enabled:contains("posix") then + local modules_disabled = set.new(config.get("*", "modules_disabled")); + if prosody.platform ~= "posix" or modules_disabled:contains("posix") then return false, "no-posix"; end @@ -228,7 +229,7 @@ local function isrunning() return true, signal.kill(pid, 0) == 0; end -local function start() +local function start(source_dir) local ok, ret = isrunning(); if not ok then return ok, ret; @@ -236,10 +237,10 @@ local function start() if ret then return false, "already-running"; end - if not CFG_SOURCEDIR then + if not source_dir then os.execute("./prosody"); else - os.execute(CFG_SOURCEDIR.."/../../bin/prosody"); + os.execute(source_dir.."/../../bin/prosody"); end return true; end diff --git a/util/pubsub.lua b/util/pubsub.lua index 1db917d8..6c5de919 100644 --- a/util/pubsub.lua +++ b/util/pubsub.lua @@ -1,32 +1,71 @@ local events = require "util.events"; local cache = require "util.cache"; -local service = {}; -local service_mt = { __index = service }; +local service_mt = {}; -local default_config = { __index = { - itemstore = function (config) return cache.new(tonumber(config["pubsub#max_items"])) end; +local default_config = { + itemstore = function (config, _) return cache.new(config["max_items"]) end; broadcaster = function () end; + itemcheck = function () return true; end; get_affiliation = function () end; + normalize_jid = function (jid) return jid; end; capabilities = {}; -} }; -local default_node_config = { __index = { - ["pubsub#max_items"] = "20"; -} }; +}; +local default_config_mt = { __index = default_config }; + +local default_node_config = { + ["persist_items"] = false; + ["max_items"] = 20; +}; +local default_node_config_mt = { __index = default_node_config }; + +-- Storage helper functions + +local function load_node_from_store(nodestore, node_name) + local node = nodestore:get(node_name); + node.config = setmetatable(node.config or {}, default_node_config_mt); + return node; +end +local function save_node_to_store(nodestore, node) + return nodestore:set(node.name, { + name = node.name; + config = node.config; + subscribers = node.subscribers; + affiliations = node.affiliations; + }); +end + +-- Create and return a new service object local function new(config) config = config or {}; - return setmetatable({ - config = setmetatable(config, default_config); - node_defaults = setmetatable(config.node_defaults or {}, default_node_config); + + local service = setmetatable({ + config = setmetatable(config, default_config_mt); + node_defaults = setmetatable(config.node_defaults or {}, default_node_config_mt); affiliations = {}; subscriptions = {}; nodes = {}; data = {}; events = events.new(); }, service_mt); + + -- Load nodes from storage, if we have a store and it supports iterating over stored items + if config.nodestore and config.nodestore.users then + for node_name in config.nodestore:users() do + service.nodes[node_name] = load_node_from_store(config.nodestore, node_name); + service.data[node_name] = config.itemstore(service.nodes[node_name].config, node_name); + end + end + + return service; end +--- Service methods + +local service = {}; +service_mt.__index = service; + function service:jids_equal(jid1, jid2) local normalize = self.config.normalize_jid; return normalize(jid1) == normalize(jid2); @@ -36,7 +75,8 @@ function service:may(node, actor, action) if actor == true then return true; end local node_obj = self.nodes[node]; - local node_aff = node_obj and node_obj.affiliations[actor]; + local node_aff = node_obj and (node_obj.affiliations[actor] + or node_obj.affiliations[self.config.normalize_jid(actor)]); local service_aff = self.affiliations[actor] or self.config.get_affiliation(actor, node, action) or "none"; @@ -76,6 +116,7 @@ function service:set_affiliation(node, actor, jid, affiliation) if not node_obj then return false, "item-not-found"; end + jid = self.config.normalize_jid(jid); node_obj.affiliations[jid] = affiliation; local _, jid_sub = self:get_subscription(node, true, jid); if not jid_sub and not self:may(node, jid, "be_unsubscribed") then @@ -176,18 +217,6 @@ function service:remove_subscription(node, actor, jid) return true; end -function service:remove_all_subscriptions(actor, jid) - local normal_jid = self.config.normalize_jid(jid); - local subs = self.subscriptions[normal_jid] - subs = subs and subs[jid]; - if subs then - for node in pairs(subs) do - self:remove_subscription(node, true, jid); - end - end - return true; -end - function service:get_subscription(node, actor, jid) -- Access checking local cap; @@ -223,14 +252,27 @@ function service:create(node, actor, options) config = setmetatable(options or {}, {__index=self.node_defaults}); affiliations = {}; }; - self.data[node] = self.config.itemstore(self.nodes[node].config); + + if self.config.nodestore then + local ok, err = save_node_to_store(self.config.nodestore, self.nodes[node]); + if not ok then + self.nodes[node] = nil; + return ok, err; + end + end + + self.data[node] = self.config.itemstore(self.nodes[node].config, node); self.events.fire_event("node-created", { node = node, actor = actor }); - local ok, err = self:set_affiliation(node, true, actor, "owner"); - if not ok then - self.nodes[node] = nil; - self.data[node] = nil; + if actor ~= true then + local ok, err = self:set_affiliation(node, true, actor, "owner"); + if not ok then + self.nodes[node] = nil; + self.data[node] = nil; + return ok, err; + end end - return ok, err; + + return true; end function service:delete(node, actor) @@ -244,9 +286,12 @@ function service:delete(node, actor) return false, "item-not-found"; end self.nodes[node] = nil; + if self.data[node] and self.data[node].clear then + self.data[node]:clear(); + end self.data[node] = nil; self.events.fire_event("node-deleted", { node = node, actor = actor }); - self.config.broadcaster("delete", node, node_obj.subscribers); + self.config.broadcaster("delete", node, node_obj.subscribers, nil, actor, node_obj, self); return true; end @@ -267,13 +312,17 @@ function service:publish(node, actor, id, item) end node_obj = self.nodes[node]; end + if not self.config.itemcheck(item) then + return nil, "internal-server-error"; + end local node_data = self.data[node]; local ok = node_data:set(id, item); if not ok then return nil, "internal-server-error"; end + if type(ok) == "string" then id = ok; end self.events.fire_event("item-published", { node = node, actor = actor, id = id, item = item }); - self.config.broadcaster("items", node, node_obj.subscribers, item, actor); + self.config.broadcaster("items", node, node_obj.subscribers, item, actor, node_obj, self); return true; end @@ -293,7 +342,7 @@ function service:retract(node, actor, id, retract) end self.events.fire_event("item-retracted", { node = node, actor = actor, id = id }); if retract then - self.config.broadcaster("items", node, node_obj.subscribers, retract); + self.config.broadcaster("items", node, node_obj.subscribers, retract, actor, node_obj, self); end return true end @@ -308,10 +357,14 @@ function service:purge(node, actor, notify) if not node_obj then return false, "item-not-found"; end - self.data[node] = self.config.itemstore(self.nodes[node].config); + if self.data[node] and self.data[node].clear then + self.data[node]:clear() + else + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + end self.events.fire_event("node-purged", { node = node, actor = actor }); if notify then - self.config.broadcaster("purge", node, node_obj.subscribers); + self.config.broadcaster("purge", node, node_obj.subscribers, nil, actor, node_obj, self); end return true end @@ -327,7 +380,11 @@ function service:get_items(node, actor, id) return false, "item-not-found"; end if id then -- Restrict results to a single specific item - return true, { id, [id] = self.data[node]:get(id) }; + local with_id = self.data[node]:get(id); + if not with_id then + return true, { }; + end + return true, { id, [id] = with_id }; else local data = {} for key, value in self.data[node]:items() do @@ -338,6 +395,15 @@ function service:get_items(node, actor, id) end end +function service:get_last_item(node, actor) + -- Access checking + if not self:may(node, actor, "get_items") then + return false, "forbidden"; + end + -- + return true, self.data[node]:tail(); +end + function service:get_nodes(actor) -- Access checking if not self:may(nil, actor, "get_nodes") then @@ -421,14 +487,14 @@ function service:set_node_config(node, actor, new_config) return false, "item-not-found"; end - for k,v in pairs(new_config) do - node_obj.config[k] = v; + if new_config["persist_items"] ~= node_obj.config["persist_items"] then + self.data[node] = self.config.itemstore(self.nodes[node].config, node); + elseif new_config["max_items"] ~= node_obj.config["max_items"] then + self.data[node]:resize(new_config["max_items"]); end - local new_data = self.config.itemstore(self.nodes[node].config); - for key, value in self.data[node]:items() do - new_data:set(key, value); - end - self.data[node] = new_data; + + node_obj.config = setmetatable(new_config, {__index=self.node_defaults}); + return true; end diff --git a/util/random.lua b/util/random.lua index b2d0000d..d8a84514 100644 --- a/util/random.lua +++ b/util/random.lua @@ -11,9 +11,6 @@ if ok then return crand; end local urandom, urandom_err = io.open("/dev/urandom", "r"); -local function seed() -end - local function bytes(n) return urandom:read(n); end @@ -25,7 +22,6 @@ if not urandom then end return { - seed = seed; bytes = bytes; _source = "/dev/urandom"; }; diff --git a/util/sasl.lua b/util/sasl.lua index 5845f34a..50851405 100644 --- a/util/sasl.lua +++ b/util/sasl.lua @@ -20,6 +20,7 @@ local assert = assert; local require = require; local _ENV = nil; +-- luacheck: std none --[[ Authentication Backend Prototypes: @@ -42,7 +43,7 @@ Example: local method = {}; method.__index = method; -local mechanisms = {}; +local registered_mechanisms = {}; local backend_mechanism = {}; local mechanism_channelbindings = {}; @@ -52,7 +53,7 @@ local function registerMechanism(name, backends, f, cb_backends) assert(type(backends) == "string" or type(backends) == "table", "Parameter backends MUST be either a string or a table."); assert(type(f) == "function", "Parameter f MUST be a function."); if cb_backends then assert(type(cb_backends) == "table"); end - mechanisms[name] = f + registered_mechanisms[name] = f if cb_backends then mechanism_channelbindings[name] = {}; for _, cb_name in ipairs(cb_backends) do @@ -70,7 +71,7 @@ local function new(realm, profile) local mechanisms = profile.mechanisms; if not mechanisms then mechanisms = {}; - for backend, f in pairs(profile) do + for backend in pairs(profile) do if backend_mechanism[backend] then for _, mechanism in ipairs(backend_mechanism[backend]) do mechanisms[mechanism] = true; @@ -128,7 +129,7 @@ end -- feed new messages to process into the library function method:process(message) --if message == "" or message == nil then return "failure", "malformed-request" end - return mechanisms[self.selected](self, message); + return registered_mechanisms[self.selected](self, message); end -- load the mechanisms diff --git a/util/sasl/anonymous.lua b/util/sasl/anonymous.lua index 6201db32..de98a5e2 100644 --- a/util/sasl/anonymous.lua +++ b/util/sasl/anonymous.lua @@ -12,9 +12,10 @@ -- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -local generate_uuid = require "util.uuid".generate; +local generate_random_id = require "util.id".medium; local _ENV = nil; +-- luacheck: std none --========================= --SASL ANONYMOUS according to RFC 4505 @@ -28,10 +29,10 @@ anonymous: end ]] -local function anonymous(self, message) +local function anonymous(self, message) -- luacheck: ignore 212/message local username; repeat - username = generate_uuid(); + username = generate_random_id():lower(); until self.profile.anonymous(self, username, self.realm); self.username = username; return "success" diff --git a/util/sasl/digest-md5.lua b/util/sasl/digest-md5.lua index 695dd2a3..7542a037 100644 --- a/util/sasl/digest-md5.lua +++ b/util/sasl/digest-md5.lua @@ -26,6 +26,7 @@ local generate_uuid = require "util.uuid".generate; local nodeprep = require "util.encodings".stringprep.nodeprep; local _ENV = nil; +-- luacheck: std none --========================= --SASL DIGEST-MD5 according to RFC 2831 diff --git a/util/sasl/external.lua b/util/sasl/external.lua index 5ba90190..ce50743e 100644 --- a/util/sasl/external.lua +++ b/util/sasl/external.lua @@ -1,6 +1,7 @@ local saslprep = require "util.encodings".stringprep.saslprep; local _ENV = nil; +-- luacheck: std none local function external(self, message) message = saslprep(message); diff --git a/util/sasl/plain.lua b/util/sasl/plain.lua index cd59b1ac..00c6bd20 100644 --- a/util/sasl/plain.lua +++ b/util/sasl/plain.lua @@ -17,6 +17,7 @@ local nodeprep = require "util.encodings".stringprep.nodeprep; local log = require "util.logger".init("sasl"); local _ENV = nil; +-- luacheck: std none -- ================================ -- SASL PLAIN according to RFC 4616 diff --git a/util/sasl/scram.lua b/util/sasl/scram.lua index 4e20dbb9..043f328b 100644 --- a/util/sasl/scram.lua +++ b/util/sasl/scram.lua @@ -26,6 +26,7 @@ local char = string.char; local byte = string.byte; local _ENV = nil; +-- luacheck: std none --========================= --SASL SCRAM-SHA-1 according to RFC 5802 @@ -46,7 +47,18 @@ Supported Channel Binding Backends local default_i = 4096 -local xor_map = {0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;1;0;3;2;5;4;7;6;9;8;11;10;13;12;15;14;2;3;0;1;6;7;4;5;10;11;8;9;14;15;12;13;3;2;1;0;7;6;5;4;11;10;9;8;15;14;13;12;4;5;6;7;0;1;2;3;12;13;14;15;8;9;10;11;5;4;7;6;1;0;3;2;13;12;15;14;9;8;11;10;6;7;4;5;2;3;0;1;14;15;12;13;10;11;8;9;7;6;5;4;3;2;1;0;15;14;13;12;11;10;9;8;8;9;10;11;12;13;14;15;0;1;2;3;4;5;6;7;9;8;11;10;13;12;15;14;1;0;3;2;5;4;7;6;10;11;8;9;14;15;12;13;2;3;0;1;6;7;4;5;11;10;9;8;15;14;13;12;3;2;1;0;7;6;5;4;12;13;14;15;8;9;10;11;4;5;6;7;0;1;2;3;13;12;15;14;9;8;11;10;5;4;7;6;1;0;3;2;14;15;12;13;10;11;8;9;6;7;4;5;2;3;0;1;15;14;13;12;11;10;9;8;7;6;5;4;3;2;1;0;}; +local xor_map = { + 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,0,3,2,5,4,7,6,9,8,11,10, + 13,12,15,14,2,3,0,1,6,7,4,5,10,11,8,9,14,15,12,13,3,2,1,0,7,6,5, + 4,11,10,9,8,15,14,13,12,4,5,6,7,0,1,2,3,12,13,14,15,8,9,10,11,5, + 4,7,6,1,0,3,2,13,12,15,14,9,8,11,10,6,7,4,5,2,3,0,1,14,15,12,13, + 10,11,8,9,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,8,9,10,11,12,13, + 14,15,0,1,2,3,4,5,6,7,9,8,11,10,13,12,15,14,1,0,3,2,5,4,7,6,10, + 11,8,9,14,15,12,13,2,3,0,1,6,7,4,5,11,10,9,8,15,14,13,12,3,2,1, + 0,7,6,5,4,12,13,14,15,8,9,10,11,4,5,6,7,0,1,2,3,13,12,15,14,9,8, + 11,10,5,4,7,6,1,0,3,2,14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1,15, + 14,13,12,11,10,9,8,7,6,5,4,3,2,1,0, +}; local result = {}; local function binaryXOR( a, b ) @@ -148,7 +160,7 @@ local function scram_gen(hash_name, H_f, HMAC_f) end self.username = username; - -- retreive credentials + -- retrieve credentials local stored_key, server_key, salt, iteration_count; if self.profile.plain then local password, status = self.profile.plain(self, username, self.realm) @@ -237,10 +249,14 @@ end local function init(registerMechanism) local function registerSCRAMMechanism(hash_name, hash, hmac_hash) - registerMechanism("SCRAM-"..hash_name, {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash)); + registerMechanism("SCRAM-"..hash_name, + {"plain", "scram_"..(hashprep(hash_name))}, + scram_gen(hash_name:lower(), hash, hmac_hash)); -- register channel binding equivalent - registerMechanism("SCRAM-"..hash_name.."-PLUS", {"plain", "scram_"..(hashprep(hash_name))}, scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); + registerMechanism("SCRAM-"..hash_name.."-PLUS", + {"plain", "scram_"..(hashprep(hash_name))}, + scram_gen(hash_name:lower(), hash, hmac_hash), {"tls-unique"}); end registerSCRAMMechanism("SHA-1", sha1, hmac_sha1); diff --git a/util/sasl_cyrus.lua b/util/sasl_cyrus.lua index 4e9a4af5..a6bd0628 100644 --- a/util/sasl_cyrus.lua +++ b/util/sasl_cyrus.lua @@ -61,6 +61,7 @@ local sasl_errstring = { setmetatable(sasl_errstring, { __index = function() return "undefined error!" end }); local _ENV = nil; +-- luacheck: std none local method = {}; method.__index = method; diff --git a/util/serialization.lua b/util/serialization.lua index 206f5fbb..54c8110f 100644 --- a/util/serialization.lua +++ b/util/serialization.lua @@ -21,6 +21,7 @@ local log = require "util.logger".init("serialization"); local envload = require"util.envload".envload; local _ENV = nil; +-- luacheck: std none local indent = function(i) return string_rep("\t", i); diff --git a/util/set.lua b/util/set.lua index c136a522..a4f20138 100644 --- a/util/set.lua +++ b/util/set.lua @@ -11,8 +11,9 @@ local ipairs, pairs, setmetatable, next, tostring = local t_concat = table.concat; local _ENV = nil; +-- luacheck: std none -local set_mt = {}; +local set_mt = { __name = "set" }; function set_mt.__call(set, _, k) return next(set._items, k); end diff --git a/util/sql.lua b/util/sql.lua index d964025e..67a5d683 100644 --- a/util/sql.lua +++ b/util/sql.lua @@ -1,11 +1,10 @@ local setmetatable, getmetatable = setmetatable, getmetatable; -local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 -local tonumber, tostring = tonumber, tostring; +local ipairs, unpack, select = ipairs, table.unpack or unpack, select; --luacheck: ignore 113 143 +local tostring = tostring; local type = type; local assert, pcall, xpcall, debug_traceback = assert, pcall, xpcall, debug.traceback; local t_concat = table.concat; -local s_char = string.char; local log = require "util.logger".init("sql"); local DBI = require "DBI"; @@ -15,6 +14,7 @@ DBI.Drivers(); local build_url = require "socket.url".build; local _ENV = nil; +-- luacheck: std none local column_mt = {}; local table_mt = {}; @@ -58,9 +58,6 @@ table_mt.__index = {}; function table_mt.__index:create(engine) return engine:_create_table(self); end -function table_mt:__call(...) - -- TODO -end function column_mt:__tostring() return 'Column{ name="'..self.name..'", type="'..self.type..'" }' end @@ -71,31 +68,6 @@ function index_mt:__tostring() -- return 'Index{ name="'..self.name..'", type="'..self.type..'" }' end -local function urldecode(s) return s and (s:gsub("%%(%x%x)", function (c) return s_char(tonumber(c,16)); end)); end -local function parse_url(url) - local scheme, secondpart, database = url:match("^([%w%+]+)://([^/]*)/?(.*)"); - assert(scheme, "Invalid URL format"); - local username, password, host, port; - local authpart, hostpart = secondpart:match("([^@]+)@([^@+])"); - if not authpart then hostpart = secondpart; end - if authpart then - username, password = authpart:match("([^:]*):(.*)"); - username = username or authpart; - password = password and urldecode(password); - end - if hostpart then - host, port = hostpart:match("([^:]*):(.*)"); - host = host or hostpart; - port = port and assert(tonumber(port), "Invalid URL format"); - end - return { - scheme = scheme:lower(); - username = username; password = password; - host = host; port = port; - database = #database > 0 and database or nil; - }; -end - local engine = {}; function engine:connect() if self.conn then return true; end @@ -123,7 +95,7 @@ function engine:connect() end return true; end -function engine:onconnect() +function engine:onconnect() -- luacheck: ignore 212/self -- Override from create_engine() end @@ -148,6 +120,7 @@ function engine:execute(sql, ...) prepared[sql] = stmt; end + -- luacheck: ignore 411/success local success, err = stmt:execute(...); if not success then return success, err; end return stmt; @@ -161,14 +134,14 @@ local result_mt = { __index = { local function debugquery(where, sql, ...) local i = 0; local a = {...} sql = sql:gsub("\n?\t+", " "); - log("debug", "[%s] %s", where, sql:gsub("%?", function () + log("debug", "[%s] %s", where, (sql:gsub("%?", function () i = i + 1; local v = a[i]; if type(v) == "string" then v = ("'%s'"):format(v:gsub("'", "''")); end return tostring(v); - end)); + end))); end function engine:execute_query(sql, ...) @@ -335,7 +308,12 @@ function engine:set_encoding() -- to UTF-8 local charset = "utf8"; if driver == "MySQL" then self:transaction(function() - for row in self:select"SELECT \"CHARACTER_SET_NAME\" FROM \"information_schema\".\"CHARACTER_SETS\" WHERE \"CHARACTER_SET_NAME\" LIKE 'utf8%' ORDER BY MAXLEN DESC LIMIT 1;" do + for row in self:select[[ + SELECT "CHARACTER_SET_NAME" + FROM "information_schema"."CHARACTER_SETS" + WHERE "CHARACTER_SET_NAME" LIKE 'utf8%' + ORDER BY MAXLEN DESC LIMIT 1; + ]] do charset = row and row[1] or charset; end end); @@ -379,7 +357,7 @@ local function db2uri(params) }; end -local function create_engine(self, params, onconnect) +local function create_engine(_, params, onconnect) return setmetatable({ url = db2uri(params), params = params, onconnect = onconnect }, engine_mt); end diff --git a/util/sslconfig.lua b/util/sslconfig.lua index 4c4e1d48..5c685f7d 100644 --- a/util/sslconfig.lua +++ b/util/sslconfig.lua @@ -8,6 +8,7 @@ local t_insert = table.insert; local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local handlers = { }; local finalisers = { }; diff --git a/util/stanza.lua b/util/stanza.lua index 2191fa8e..1f67c75b 100644 --- a/util/stanza.lua +++ b/util/stanza.lua @@ -7,6 +7,7 @@ -- +local error = error; local t_insert = table.insert; local t_remove = table.remove; local t_concat = table.concat; @@ -23,6 +24,8 @@ local s_sub = string.sub; local s_find = string.find; local os = os; +local valid_utf8 = require "util.encodings".utf8.valid; + local do_pretty_printing = not os.getenv("WINDIR"); local getstyle, getstring; if do_pretty_printing then @@ -37,12 +40,52 @@ end local xmlns_stanzas = "urn:ietf:params:xml:ns:xmpp-stanzas"; local _ENV = nil; +-- luacheck: std none -local stanza_mt = { __type = "stanza" }; +local stanza_mt = { __name = "stanza" }; stanza_mt.__index = stanza_mt; -local function new_stanza(name, attr) - local stanza = { name = name, attr = attr or {}, tags = {} }; +local function check_name(name, name_type) + if type(name) ~= "string" then + error("invalid "..name_type.." name: expected string, got "..type(name)); + elseif #name == 0 then + error("invalid "..name_type.." name: empty string"); + elseif s_find(name, "[<>& '\"]") then + error("invalid "..name_type.." name: contains invalid characters"); + elseif not valid_utf8(name) then + error("invalid "..name_type.." name: contains invalid utf8"); + end +end + +local function check_text(text, text_type) + if type(text) ~= "string" then + error("invalid "..text_type.." value: expected string, got "..type(text)); + elseif not valid_utf8(text) then + error("invalid "..text_type.." value: contains invalid utf8"); + end +end + +local function check_attr(attr) + if attr ~= nil then + if type(attr) ~= "table" then + error("invalid attributes, expected table got "..type(attr)); + end + for k, v in pairs(attr) do + check_name(k, "attribute"); + check_text(v, "attribute"); + if type(v) ~= "string" then + error("invalid attribute value for '"..k.."': expected string, got "..type(v)); + elseif not valid_utf8(v) then + error("invalid attribute value for '"..k.."': contains invalid utf8"); + end + end + end +end + +local function new_stanza(name, attr, namespaces) + check_name(name, "tag"); + check_attr(attr); + local stanza = { name = name, attr = attr or {}, namespaces = namespaces, tags = {} }; return setmetatable(stanza, stanza_mt); end @@ -58,8 +101,12 @@ function stanza_mt:body(text, attr) return self:tag("body", attr):text(text); end -function stanza_mt:tag(name, attrs) - local s = new_stanza(name, attrs); +function stanza_mt:text_tag(name, text, attr, namespaces) + return self:tag(name, attr, namespaces):text(text):up(); +end + +function stanza_mt:tag(name, attr, namespaces) + local s = new_stanza(name, attr, namespaces); local last_add = self.last_add; if not last_add then last_add = {}; self.last_add = last_add; end (last_add[#last_add] or self):add_direct_child(s); @@ -68,8 +115,10 @@ function stanza_mt:tag(name, attrs) end function stanza_mt:text(text) - local last_add = self.last_add; - (last_add and last_add[#last_add] or self):add_direct_child(text); + if text ~= nil and text ~= "" then + local last_add = self.last_add; + (last_add and last_add[#last_add] or self):add_direct_child(text); + end return self; end @@ -85,10 +134,13 @@ function stanza_mt:reset() end function stanza_mt:add_direct_child(child) - if type(child) == "table" then + if is_stanza(child) then t_insert(self.tags, child); + t_insert(self, child); + else + check_text(child, "text"); + t_insert(self, child); end - t_insert(self, child); end function stanza_mt:add_child(child) @@ -337,7 +389,12 @@ end local function clone(stanza) local attr, tags = {}, {}; for k,v in pairs(stanza.attr) do attr[k] = v; end - local new = { name = stanza.name, attr = attr, tags = tags }; + local old_namespaces, namespaces = stanza.namespaces; + if old_namespaces then + namespaces = {}; + for k,v in pairs(old_namespaces) do namespaces[k] = v; end + end + local new = { name = stanza.name, attr = attr, namespaces = namespaces, tags = tags }; for i=1,#stanza do local child = stanza[i]; if child.name then @@ -362,7 +419,13 @@ local function iq(attr) end local function reply(orig) - return new_stanza(orig.name, orig.attr and { to = orig.attr.from, from = orig.attr.to, id = orig.attr.id, type = ((orig.name == "iq" and "result") or orig.attr.type) }); + return new_stanza(orig.name, + orig.attr and { + to = orig.attr.from, + from = orig.attr.to, + id = orig.attr.id, + type = ((orig.name == "iq" and "result") or orig.attr.type) + }); end local xmpp_stanzas_attr = { xmlns = xmlns_stanzas }; diff --git a/util/startup.lua b/util/startup.lua new file mode 100644 index 00000000..451a0587 --- /dev/null +++ b/util/startup.lua @@ -0,0 +1,550 @@ +-- Ignore the CFG_* variables +-- luacheck: ignore 113/CFG_CONFIGDIR 113/CFG_SOURCEDIR 113/CFG_DATADIR 113/CFG_PLUGINDIR +local startup = {}; + +local prosody = { events = require "util.events".new() }; +local logger = require "util.logger"; +local log = logger.init("startup"); + +local config = require "core.configmanager"; + +local dependencies = require "util.dependencies"; + +local original_logging_config; + +function startup.read_config() + local filenames = {}; + + local filename; + if arg[1] == "--config" and arg[2] then + table.insert(filenames, arg[2]); + if CFG_CONFIGDIR then + table.insert(filenames, CFG_CONFIGDIR.."/"..arg[2]); + end + table.remove(arg, 1); table.remove(arg, 1); + elseif os.getenv("PROSODY_CONFIG") then -- Passed by prosodyctl + table.insert(filenames, os.getenv("PROSODY_CONFIG")); + else + table.insert(filenames, (CFG_CONFIGDIR or ".").."/prosody.cfg.lua"); + end + for _,_filename in ipairs(filenames) do + filename = _filename; + local file = io.open(filename); + if file then + file:close(); + prosody.config_file = filename; + CFG_CONFIGDIR = filename:match("^(.*)[\\/][^\\/]*$"); -- luacheck: ignore 111 + break; + end + end + prosody.config_file = filename + local ok, level, err = config.load(filename); + if not ok then + print("\n"); + print("**************************"); + if level == "parser" then + print("A problem occurred while reading the config file "..filename); + print(""); + local err_line, err_message = tostring(err):match("%[string .-%]:(%d*): (.*)"); + if err:match("chunk has too many syntax levels$") then + print("An Include statement in a config file is including an already-included"); + print("file and causing an infinite loop. An Include statement in a config file is..."); + else + print("Error"..(err_line and (" on line "..err_line) or "")..": "..(err_message or tostring(err))); + end + print(""); + elseif level == "file" then + print("Prosody was unable to find the configuration file."); + print("We looked for: "..filename); + print("A sample config file is included in the Prosody download called prosody.cfg.lua.dist"); + print("Copy or rename it to prosody.cfg.lua and edit as necessary."); + end + print("More help on configuring Prosody can be found at https://prosody.im/doc/configure"); + print("Good luck!"); + print("**************************"); + print(""); + os.exit(1); + end +end + +function startup.check_dependencies() + if not dependencies.check_dependencies() then + os.exit(1); + end +end + +-- luacheck: globals socket server + +function startup.load_libraries() + -- Load socket framework + -- luacheck: ignore 111/server 111/socket + socket = require "socket"; + server = require "net.server" +end + +function startup.init_logging() + -- Initialize logging + local loggingmanager = require "core.loggingmanager" + loggingmanager.reload_logging(); + prosody.events.add_handler("reopen-log-files", function () + loggingmanager.reload_logging(); + prosody.events.fire_event("logging-reloaded"); + end); +end + +function startup.log_dependency_warnings() + dependencies.log_warnings(); +end + +function startup.sanity_check() + for host, host_config in pairs(config.getconfig()) do + if host ~= "*" + and host_config.enabled ~= false + and not host_config.component_module then + return; + end + end + log("error", "No enabled VirtualHost entries found in the config file."); + log("error", "At least one active host is required for Prosody to function. Exiting..."); + os.exit(1); +end + +function startup.sandbox_require() + -- Replace require() with one that doesn't pollute _G, required + -- for neat sandboxing of modules + -- luacheck: ignore 113/getfenv 111/require + local _realG = _G; + local _real_require = require; + local getfenv = getfenv or function (f) + -- FIXME: This is a hack to replace getfenv() in Lua 5.2 + local name, env = debug.getupvalue(debug.getinfo(f or 1).func, 1); + if name == "_ENV" then + return env; + end + end + function require(...) -- luacheck: ignore 121 + local curr_env = getfenv(2); + local curr_env_mt = getmetatable(curr_env); + local _realG_mt = getmetatable(_realG); + if curr_env_mt and curr_env_mt.__index and not curr_env_mt.__newindex and _realG_mt then + local old_newindex, old_index; + old_newindex, _realG_mt.__newindex = _realG_mt.__newindex, curr_env; + old_index, _realG_mt.__index = _realG_mt.__index, function (_G, k) -- luacheck: ignore 212/_G + return rawget(curr_env, k); + end; + local ret = _real_require(...); + _realG_mt.__newindex = old_newindex; + _realG_mt.__index = old_index; + return ret; + end + return _real_require(...); + end +end + +function startup.set_function_metatable() + local mt = {}; + function mt.__index(f, upvalue) + local i, name, value = 0; + repeat + i = i + 1; + name, value = debug.getupvalue(f, i); + until name == upvalue or name == nil; + return value; + end + function mt.__newindex(f, upvalue, value) + local i, name = 0; + repeat + i = i + 1; + name = debug.getupvalue(f, i); + until name == upvalue or name == nil; + if name then + debug.setupvalue(f, i, value); + end + end + function mt.__tostring(f) + local info = debug.getinfo(f); + return ("function(%s:%d)"):format(info.short_src:match("[^\\/]*$"), info.linedefined); + end + debug.setmetatable(function() end, mt); +end + +function startup.detect_platform() + prosody.platform = "unknown"; + if os.getenv("WINDIR") then + prosody.platform = "windows"; + elseif package.config:sub(1,1) == "/" then + prosody.platform = "posix"; + end +end + +function startup.detect_installed() + prosody.installed = nil; + if CFG_SOURCEDIR and (prosody.platform == "windows" or CFG_SOURCEDIR:match("^/")) then + prosody.installed = true; + end +end + +function startup.init_global_state() + -- luacheck: ignore 121 + prosody.bare_sessions = {}; + prosody.full_sessions = {}; + prosody.hosts = {}; + + -- COMPAT: These globals are deprecated + -- luacheck: ignore 111/bare_sessions 111/full_sessions 111/hosts + bare_sessions = prosody.bare_sessions; + full_sessions = prosody.full_sessions; + hosts = prosody.hosts; + + prosody.paths = { source = CFG_SOURCEDIR, config = CFG_CONFIGDIR or ".", + plugins = CFG_PLUGINDIR or "plugins", data = "data" }; + + prosody.arg = _G.arg; + + _G.log = logger.init("general"); + prosody.log = logger.init("general"); + + startup.detect_platform(); + startup.detect_installed(); + _G.prosody = prosody; +end + +function startup.setup_datadir() + prosody.paths.data = config.get("*", "data_path") or CFG_DATADIR or "data"; +end + +function startup.setup_plugindir() + local custom_plugin_paths = config.get("*", "plugin_paths"); + if custom_plugin_paths then + local path_sep = package.config:sub(3,3); + -- path1;path2;path3;defaultpath... + -- luacheck: ignore 111 + CFG_PLUGINDIR = table.concat(custom_plugin_paths, path_sep)..path_sep..(CFG_PLUGINDIR or "plugins"); + prosody.paths.plugins = CFG_PLUGINDIR; + end +end + +function startup.chdir() + if prosody.installed then + -- Change working directory to data path. + require "lfs".chdir(prosody.paths.data); + end +end + +function startup.add_global_prosody_functions() + -- Function to reload the config file + function prosody.reload_config() + log("info", "Reloading configuration file"); + prosody.events.fire_event("reloading-config"); + local ok, level, err = config.load(prosody.config_file); + if not ok then + if level == "parser" then + log("error", "There was an error parsing the configuration file: %s", tostring(err)); + elseif level == "file" then + log("error", "Couldn't read the config file when trying to reload: %s", tostring(err)); + end + else + prosody.events.fire_event("config-reloaded", { + filename = prosody.config_file, + config = config.getconfig(), + }); + end + return ok, (err and tostring(level)..": "..tostring(err)) or nil; + end + + -- Function to reopen logfiles + function prosody.reopen_logfiles() + log("info", "Re-opening log files"); + prosody.events.fire_event("reopen-log-files"); + end + + -- Function to initiate prosody shutdown + function prosody.shutdown(reason, code) + log("info", "Shutting down: %s", reason or "unknown reason"); + prosody.shutdown_reason = reason; + prosody.shutdown_code = code; + prosody.events.fire_event("server-stopping", { + reason = reason; + code = code; + }); + server.setquitting(true); + end +end + +function startup.load_secondary_libraries() + --- Load and initialise core modules + require "util.import" + require "util.xmppstream" + require "core.stanza_router" + require "core.statsmanager" + require "core.hostmanager" + require "core.portmanager" + require "core.modulemanager" + require "core.usermanager" + require "core.rostermanager" + require "core.sessionmanager" + package.loaded['core.componentmanager'] = setmetatable({},{__index=function() + log("warn", "componentmanager is deprecated: %s", debug.traceback():match("\n[^\n]*\n[ \t]*([^\n]*)")); + return function() end + end}); + + require "util.array" + require "util.datetime" + require "util.iterators" + require "util.timer" + require "util.helpers" + + pcall(require, "util.signal") -- Not on Windows + + -- Commented to protect us from + -- the second kind of people + --[[ + pcall(require, "remdebug.engine"); + if remdebug then remdebug.engine.start() end + ]] + + require "util.stanza" + require "util.jid" +end + +function startup.init_http_client() + local http = require "net.http" + local config_ssl = config.get("*", "ssl") or {} + local https_client = config.get("*", "client_https_ssl") + http.default.options.sslctx = require "core.certmanager".create_context("client_https port 0", "client", + { capath = config_ssl.capath, cafile = config_ssl.cafile, verify = "peer", }, https_client); +end + +function startup.init_data_store() + require "core.storagemanager"; +end + +function startup.prepare_to_start() + log("info", "Prosody is using the %s backend for connection handling", server.get_backend()); + -- Signal to modules that we are ready to start + prosody.events.fire_event("server-starting"); + prosody.start_time = os.time(); +end + +function startup.init_global_protection() + -- Catch global accesses + -- luacheck: ignore 212/t + local locked_globals_mt = { + __index = function (t, k) log("warn", "%s", debug.traceback("Attempt to read a non-existent global '"..tostring(k).."'", 2)); end; + __newindex = function (t, k, v) error("Attempt to set a global: "..tostring(k).." = "..tostring(v), 2); end; + }; + + function prosody.unlock_globals() + setmetatable(_G, nil); + end + + function prosody.lock_globals() + setmetatable(_G, locked_globals_mt); + end + + -- And lock now... + prosody.lock_globals(); +end + +function startup.read_version() + -- Try to determine version + local version_file = io.open((CFG_SOURCEDIR or ".").."/prosody.version"); + prosody.version = "unknown"; + if version_file then + prosody.version = version_file:read("*a"):gsub("%s*$", ""); + version_file:close(); + if #prosody.version == 12 and prosody.version:match("^[a-f0-9]+$") then + prosody.version = "hg:"..prosody.version; + end + else + local hg = require"util.mercurial"; + local hgid = hg.check_id(CFG_SOURCEDIR or "."); + if hgid then prosody.version = "hg:" .. hgid; end + end +end + +function startup.log_greeting() + log("info", "Hello and welcome to Prosody version %s", prosody.version); +end + +function startup.notify_started() + prosody.events.fire_event("server-started"); +end + +-- Override logging config (used by prosodyctl) +function startup.force_console_logging() + original_logging_config = config.get("*", "log"); + config.set("*", "log", { { levels = { min = os.getenv("PROSODYCTL_LOG_LEVEL") or "info" }, to = "console" } }); +end + +function startup.switch_user() + -- Switch away from root and into the prosody user -- + -- NOTE: This function is only used by prosodyctl. + -- The prosody process is built with the assumption that + -- it is already started as the appropriate user. + + local want_pposix_version = "0.4.0"; + local have_pposix, pposix = pcall(require, "util.pposix"); + + if have_pposix and pposix then + if pposix._VERSION ~= want_pposix_version then + print(string.format("Unknown version (%s) of binary pposix module, expected %s", + tostring(pposix._VERSION), want_pposix_version)); + os.exit(1); + end + prosody.current_uid = pposix.getuid(); + local arg_root = arg[1] == "--root"; + if arg_root then table.remove(arg, 1); end + if prosody.current_uid == 0 and config.get("*", "run_as_root") ~= true and not arg_root then + -- We haz root! + local desired_user = config.get("*", "prosody_user") or "prosody"; + local desired_group = config.get("*", "prosody_group") or desired_user; + local ok, err = pposix.setgid(desired_group); + if ok then + ok, err = pposix.initgroups(desired_user); + end + if ok then + ok, err = pposix.setuid(desired_user); + if ok then + -- Yay! + prosody.switched_user = true; + end + end + if not prosody.switched_user then + -- Boo! + print("Warning: Couldn't switch to Prosody user/group '"..tostring(desired_user).."'/'"..tostring(desired_group).."': "..tostring(err)); + else + -- Make sure the Prosody user can read the config + local conf, err, errno = io.open(prosody.config_file); + if conf then + conf:close(); + else + print("The config file is not readable by the '"..desired_user.."' user."); + print("Prosody will not be able to read it."); + print("Error was "..err); + os.exit(1); + end + end + end + + -- Set our umask to protect data files + pposix.umask(config.get("*", "umask") or "027"); + pposix.setenv("HOME", prosody.paths.data); + pposix.setenv("PROSODY_CONFIG", prosody.config_file); + else + print("Error: Unable to load pposix module. Check that Prosody is installed correctly.") + print("For more help send the below error to us through https://prosody.im/discuss"); + print(tostring(pposix)) + os.exit(1); + end +end + +function startup.check_unwriteable() + local function test_writeable(filename) + local f, err = io.open(filename, "a"); + if not f then + return false, err; + end + f:close(); + return true; + end + + local unwriteable_files = {}; + if type(original_logging_config) == "string" and original_logging_config:sub(1,1) ~= "*" then + local ok, err = test_writeable(original_logging_config); + if not ok then + table.insert(unwriteable_files, err); + end + elseif type(original_logging_config) == "table" then + for _, rule in ipairs(original_logging_config) do + if rule.filename then + local ok, err = test_writeable(rule.filename); + if not ok then + table.insert(unwriteable_files, err); + end + end + end + end + + if #unwriteable_files > 0 then + print("One of more of the Prosody log files are not"); + print("writeable, please correct the errors and try"); + print("starting prosodyctl again."); + print(""); + for _, err in ipairs(unwriteable_files) do + print(err); + end + print(""); + os.exit(1); + end +end + +function startup.make_host(hostname) + return { + type = "local", + events = prosody.events, + modules = {}, + sessions = {}, + users = require "core.usermanager".new_null_provider(hostname) + }; +end + +function startup.make_dummy_hosts() + -- When running under prosodyctl, we don't want to + -- fully initialize the server, so we populate prosody.hosts + -- with just enough things for most code to work correctly + -- luacheck: ignore 122/hosts + prosody.core_post_stanza = function () end; -- TODO: mod_router! + + for hostname in pairs(config.getconfig()) do + prosody.hosts[hostname] = startup.make_host(hostname); + end +end + +-- prosodyctl only +function startup.prosodyctl() + startup.init_global_state(); + startup.read_config(); + startup.force_console_logging(); + startup.init_logging(); + startup.setup_plugindir(); + startup.setup_datadir(); + startup.chdir(); + startup.read_version(); + startup.switch_user(); + startup.check_dependencies(); + startup.log_dependency_warnings(); + startup.check_unwriteable(); + startup.load_libraries(); + startup.init_http_client(); + startup.make_dummy_hosts(); +end + +function startup.prosody() + -- These actions are in a strict order, as many depend on + -- previous steps to have already been performed + startup.init_global_state(); + startup.read_config(); + startup.init_logging(); + startup.sanity_check(); + startup.sandbox_require(); + startup.set_function_metatable(); + startup.check_dependencies(); + startup.init_logging(); + startup.load_libraries(); + startup.setup_plugindir(); + startup.setup_datadir(); + startup.chdir(); + startup.add_global_prosody_functions(); + startup.read_version(); + startup.log_greeting(); + startup.log_dependency_warnings(); + startup.load_secondary_libraries(); + startup.init_http_client(); + startup.init_data_store(); + startup.init_global_protection(); + startup.prepare_to_start(); + startup.notify_started(); +end + +return startup; diff --git a/util/template.lua b/util/template.lua index 04ebb93d..c11037c5 100644 --- a/util/template.lua +++ b/util/template.lua @@ -4,12 +4,13 @@ local setmetatable = setmetatable; local pairs = pairs; local ipairs = ipairs; local error = error; -local loadstring = loadstring; +local envload = require "util.envload".envload; local debug = debug; local t_remove = table.remove; local parse_xml = require "util.xml".parse; local _ENV = nil; +-- luacheck: std none local function trim_xml(stanza) for i=#stanza,1,-1 do @@ -72,7 +73,7 @@ local function create_cloner(stanza, chunkname) src = src.."local _"..i.."="..lookup[i]..";"; end src = src.."return "..name..";end"; - local f,err = loadstring(src, chunkname); + local f,err = envload(src, chunkname); if not f then error(err); end return f(setmetatable, stanza_mt); end diff --git a/util/termcolours.lua b/util/termcolours.lua index 23c9156b..829d84af 100644 --- a/util/termcolours.lua +++ b/util/termcolours.lua @@ -26,6 +26,7 @@ end local orig_color = windows and windows.get_consolecolor and windows.get_consolecolor(); local _ENV = nil; +-- luacheck: std none local stylemap = { reset = 0; bright = 1, dim = 2, underscore = 4, blink = 5, reverse = 7, hidden = 8; diff --git a/util/throttle.lua b/util/throttle.lua index 1012f78a..d2036e9e 100644 --- a/util/throttle.lua +++ b/util/throttle.lua @@ -3,6 +3,7 @@ local gettime = require "util.time".now local setmetatable = setmetatable; local _ENV = nil; +-- luacheck: std none local throttle = {}; local throttle_mt = { __index = throttle }; diff --git a/util/timer.lua b/util/timer.lua index 7e2e9414..424d44fa 100644 --- a/util/timer.lua +++ b/util/timer.lua @@ -6,78 +6,114 @@ -- COPYING file in the source package for more information. -- +local indexedbheap = require "util.indexedbheap"; +local log = require "util.logger".init("timer"); local server = require "net.server"; -local math_min = math.min -local math_huge = math.huge local get_time = require "util.time".now -local t_insert = table.insert; -local pairs = pairs; +local async = require "util.async"; local type = type; - -local data = {}; -local new_data = {}; +local debug_traceback = debug.traceback; +local tostring = tostring; +local xpcall = xpcall; +local math_max = math.max; local _ENV = nil; +-- luacheck: std none -local _add_task; -if not server.event then - function _add_task(delay, callback) - local current_time = get_time(); - delay = delay + current_time; - if delay >= current_time then - t_insert(new_data, {delay, callback}); - else - local r = callback(current_time); - if r and type(r) == "number" then - return _add_task(r, callback); - end +local _add_task = server.add_task; + +local _server_timer; +local _active_timers = 0; +local h = indexedbheap.create(); +local params = {}; +local next_time = nil; +local _id, _callback, _now, _param; +local function _call() return _callback(_now, _id, _param); end +local function _traceback_handler(err) log("error", "Traceback[timer]: %s", debug_traceback(tostring(err), 2)); end +local function _on_timer(now) + local peek; + while true do + peek = h:peek(); + if peek == nil or peek > now then break; end + local _; + _, _callback, _id = h:pop(); + _now = now; + _param = params[_id]; + params[_id] = nil; + --item(now, id, _param); -- FIXME pcall + local success, err = xpcall(_call, _traceback_handler); + if success and type(err) == "number" then + h:insert(_callback, err + now, _id); -- re-add + params[_id] = _param; end end - server._addtimer(function() - local current_time = get_time(); - if #new_data > 0 then - for _, d in pairs(new_data) do - t_insert(data, d); - end - new_data = {}; - end + if peek ~= nil and _active_timers > 1 and peek == next_time then + -- Another instance of _on_timer already set next_time to the same value, + -- so it should be safe to not renew this timer event + peek = nil; + else + next_time = peek; + end - local next_time = math_huge; - for i, d in pairs(data) do - local t, callback = d[1], d[2]; - if t <= current_time then - data[i] = nil; - local r = callback(current_time); - if type(r) == "number" then - _add_task(r, callback); - next_time = math_min(next_time, r); - end - else - next_time = math_min(next_time, t - current_time); - end - end - return next_time; - end); -else - local event = server.event; - local event_base = server.event_base; - local EVENT_LEAVE = (event.core and event.core.LEAVE) or -1; + if peek then + -- peek is the time of the next event + return peek - now; + end + _active_timers = _active_timers - 1; +end +local function add_task(delay, callback, param) + local current_time = get_time(); + local event_time = current_time + delay; - function _add_task(delay, callback) - local event_handle; - event_handle = event_base:addevent(nil, 0, function () - local ret = callback(get_time()); - if ret then - return 0, ret; - elseif event_handle then - return EVENT_LEAVE; - end + local id = h:insert(callback, event_time); + params[id] = param; + if next_time == nil or event_time < next_time then + next_time = event_time; + if _server_timer then + _server_timer:close(); + _server_timer = nil; + else + _active_timers = _active_timers + 1; end - , delay); + _server_timer = _add_task(next_time - current_time, _on_timer); end + return id; +end +local function stop(id) + params[id] = nil; + local result, item, result_sync = h:remove(id); + local peek = h:peek(); + if peek ~= next_time and _server_timer then + next_time = peek; + _server_timer:close(); + if next_time ~= nil then + _server_timer = _add_task(math_max(next_time - get_time(), 0), _on_timer); + end + end + return result, item, result_sync; +end +local function reschedule(id, delay) + local current_time = get_time(); + local event_time = current_time + delay; + h:reprioritize(id, delay); + if next_time == nil or event_time < next_time then + next_time = event_time; + _add_task(next_time - current_time, _on_timer); + end + return id; +end + +local function sleep(s) + local wait, done = async.waiter(); + add_task(s, done); + wait(); end return { - add_task = _add_task; + add_task = add_task; + stop = stop; + reschedule = reschedule; + sleep = sleep; }; + diff --git a/util/vcard.lua b/util/vcard.lua new file mode 100644 index 00000000..51758c41 --- /dev/null +++ b/util/vcard.lua @@ -0,0 +1,572 @@ +-- Copyright (C) 2011-2014 Kim Alvefur +-- +-- This project is MIT/X11 licensed. Please see the +-- COPYING file in the source package for more information. +-- + +-- TODO +-- Fix folding. + +local st = require "util.stanza"; +local t_insert, t_concat = table.insert, table.concat; +local type = type; +local pairs, ipairs = pairs, ipairs; + +local from_text, to_text, from_xep54, to_xep54; + +local line_sep = "\n"; + +local vCard_dtd; -- See end of file +local vCard4_dtd; + +local function vCard_esc(s) + return s:gsub("[,:;\\]", "\\%1"):gsub("\n","\\n"); +end + +local function vCard_unesc(s) + return s:gsub("\\?[\\nt:;,]", { + ["\\\\"] = "\\", + ["\\n"] = "\n", + ["\\r"] = "\r", + ["\\t"] = "\t", + ["\\:"] = ":", -- FIXME Shouldn't need to espace : in values, just params + ["\\;"] = ";", + ["\\,"] = ",", + [":"] = "\29", + [";"] = "\30", + [","] = "\31", + }); +end + +local function item_to_xep54(item) + local t = st.stanza(item.name, { xmlns = "vcard-temp" }); + + local prop_def = vCard_dtd[item.name]; + if prop_def == "text" then + t:text(item[1]); + elseif type(prop_def) == "table" then + if prop_def.types and item.TYPE then + if type(item.TYPE) == "table" then + for _,v in pairs(prop_def.types) do + for _,typ in pairs(item.TYPE) do + if typ:upper() == v then + t:tag(v):up(); + break; + end + end + end + else + t:tag(item.TYPE:upper()):up(); + end + end + + if prop_def.props then + for _,v in pairs(prop_def.props) do + if item[v] then + t:tag(v):up(); + end + end + end + + if prop_def.value then + t:tag(prop_def.value):text(item[1]):up(); + elseif prop_def.values then + local prop_def_values = prop_def.values; + local repeat_last = prop_def_values.behaviour == "repeat-last" and prop_def_values[#prop_def_values]; + for i=1,#item do + t:tag(prop_def.values[i] or repeat_last):text(item[i]):up(); + end + end + end + + return t; +end + +local function vcard_to_xep54(vCard) + local t = st.stanza("vCard", { xmlns = "vcard-temp" }); + for i=1,#vCard do + t:add_child(item_to_xep54(vCard[i])); + end + return t; +end + +function to_xep54(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_xep54(vCards) + else + local t = st.stanza("xCard", { xmlns = "vcard-temp" }); + for i=1,#vCards do + t:add_child(vcard_to_xep54(vCards[i])); + end + return t; + end +end + +function from_text(data) + data = data -- unfold and remove empty lines + :gsub("\r\n","\n") + :gsub("\n ", "") + :gsub("\n\n+","\n"); + local vCards = {}; + local current; + for line in data:gmatch("[^\n]+") do + line = vCard_unesc(line); + local name, params, value = line:match("^([-%a]+)(\30?[^\29]*)\29(.*)$"); + value = value:gsub("\29",":"); + if #params > 0 then + local _params = {}; + for k,isval,v in params:gmatch("\30([^=]+)(=?)([^\30]*)") do + k = k:upper(); + local _vt = {}; + for _p in v:gmatch("[^\31]+") do + _vt[#_vt+1]=_p + _vt[_p]=true; + end + if isval == "=" then + _params[k]=_vt; + else + _params[k]=true; + end + end + params = _params; + end + if name == "BEGIN" and value == "VCARD" then + current = {}; + vCards[#vCards+1] = current; + elseif name == "END" and value == "VCARD" then + current = nil; + elseif current and vCard_dtd[name] then + local dtd = vCard_dtd[name]; + local item = { name = name }; + t_insert(current, item); + local up = current; + current = item; + if dtd.types then + for _, t in ipairs(dtd.types) do + t = t:lower(); + if ( params.TYPE and params.TYPE[t] == true) + or params[t] == true then + current.TYPE=t; + end + end + end + if dtd.props then + for _, p in ipairs(dtd.props) do + if params[p] then + if params[p] == true then + current[p]=true; + else + for _, prop in ipairs(params[p]) do + current[p]=prop; + end + end + end + end + end + if dtd == "text" or dtd.value then + t_insert(current, value); + elseif dtd.values then + for p in ("\30"..value):gmatch("\30([^\30]*)") do + t_insert(current, p); + end + end + current = up; + end + end + return vCards; +end + +local function item_to_text(item) + local value = {}; + for i=1,#item do + value[i] = vCard_esc(item[i]); + end + value = t_concat(value, ";"); + + local params = ""; + for k,v in pairs(item) do + if type(k) == "string" and k ~= "name" then + params = params .. (";%s=%s"):format(k, type(v) == "table" and t_concat(v,",") or v); + end + end + + return ("%s%s:%s"):format(item.name, params, value) +end + +local function vcard_to_text(vcard) + local t={}; + t_insert(t, "BEGIN:VCARD") + for i=1,#vcard do + t_insert(t, item_to_text(vcard[i])); + end + t_insert(t, "END:VCARD") + return t_concat(t, line_sep); +end + +function to_text(vCards) + if vCards[1] and vCards[1].name then + return vcard_to_text(vCards) + else + local t = {}; + for i=1,#vCards do + t[i]=vcard_to_text(vCards[i]); + end + return t_concat(t, line_sep); + end +end + +local function from_xep54_item(item) + local prop_name = item.name; + local prop_def = vCard_dtd[prop_name]; + + local prop = { name = prop_name }; + + if prop_def == "text" then + prop[1] = item:get_text(); + elseif type(prop_def) == "table" then + if prop_def.value then --single item + prop[1] = item:get_child_text(prop_def.value) or ""; + elseif prop_def.values then --array + local value_names = prop_def.values; + if value_names.behaviour == "repeat-last" then + for i=1,#item.tags do + t_insert(prop, item.tags[i]:get_text() or ""); + end + else + for i=1,#value_names do + t_insert(prop, item:get_child_text(value_names[i]) or ""); + end + end + elseif prop_def.names then + local names = prop_def.names; + for i=1,#names do + if item:get_child(names[i]) then + prop[1] = names[i]; + break; + end + end + end + + if prop_def.props_verbatim then + for k,v in pairs(prop_def.props_verbatim) do + prop[k] = v; + end + end + + if prop_def.types then + local types = prop_def.types; + prop.TYPE = {}; + for i=1,#types do + if item:get_child(types[i]) then + t_insert(prop.TYPE, types[i]:lower()); + end + end + if #prop.TYPE == 0 then + prop.TYPE = nil; + end + end + + -- A key-value pair, within a key-value pair? + if prop_def.props then + local params = prop_def.props; + for i=1,#params do + local name = params[i] + local data = item:get_child_text(name); + if data then + prop[name] = prop[name] or {}; + t_insert(prop[name], data); + end + end + end + else + return nil + end + + return prop; +end + +local function from_xep54_vCard(vCard) + local tags = vCard.tags; + local t = {}; + for i=1,#tags do + t_insert(t, from_xep54_item(tags[i])); + end + return t +end + +function from_xep54(vCard) + if vCard.attr.xmlns ~= "vcard-temp" then + return nil, "wrong-xmlns"; + end + if vCard.name == "xCard" then -- A collection of vCards + local t = {}; + local vCards = vCard.tags; + for i=1,#vCards do + t[i] = from_xep54_vCard(vCards[i]); + end + return t + elseif vCard.name == "vCard" then -- A single vCard + return from_xep54_vCard(vCard) + end +end + +local vcard4 = { } + +function vcard4:text(node, params, value) -- luacheck: ignore 212/params + self:tag(node:lower()) + -- FIXME params + if type(value) == "string" then + self:tag("text"):text(value):up() + elseif vcard4[node] then + vcard4[node](value); + end + self:up(); +end + +function vcard4.N(value) + for i, k in ipairs(vCard_dtd.N.values) do + value:tag(k):text(value[i]):up(); + end +end + +local xmlns_vcard4 = "urn:ietf:params:xml:ns:vcard-4.0" + +local function item_to_vcard4(item) + local typ = item.name:lower(); + local t = st.stanza(typ, { xmlns = xmlns_vcard4 }); + + local prop_def = vCard4_dtd[typ]; + if prop_def == "text" then + t:tag("text"):text(item[1]):up(); + elseif prop_def == "uri" then + if item.ENCODING and item.ENCODING[1] == 'b' then + t:tag("uri"):text("data:;base64,"):text(item[1]):up(); + else + t:tag("uri"):text(item[1]):up(); + end + elseif type(prop_def) == "table" then + if prop_def.values then + for i, v in ipairs(prop_def.values) do + t:tag(v:lower()):text(item[i] or ""):up(); + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + else + t:tag("unsupported",{xmlns="http://zash.se/protocol/vcardlib"}) + end + return t; +end + +local function vcard_to_vcard4xml(vCard) + local t = st.stanza("vcard", { xmlns = xmlns_vcard4 }); + for i=1,#vCard do + t:add_child(item_to_vcard4(vCard[i])); + end + return t; +end + +local function vcards_to_vcard4xml(vCards) + if not vCards[1] or vCards[1].name then + return vcard_to_vcard4xml(vCards) + else + local t = st.stanza("vcards", { xmlns = xmlns_vcard4 }); + for i=1,#vCards do + t:add_child(vcard_to_vcard4xml(vCards[i])); + end + return t; + end +end + +-- This was adapted from http://xmpp.org/extensions/xep-0054.html#dtd +vCard_dtd = { + VERSION = "text", --MUST be 3.0, so parsing is redundant + FN = "text", + N = { + values = { + "FAMILY", + "GIVEN", + "MIDDLE", + "PREFIX", + "SUFFIX", + }, + }, + NICKNAME = "text", + PHOTO = { + props_verbatim = { ENCODING = { "b" } }, + props = { "TYPE" }, + value = "BINVAL", --{ "EXTVAL", }, + }, + BDAY = "text", + ADR = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + values = { + "POBOX", + "EXTADD", + "STREET", + "LOCALITY", + "REGION", + "PCODE", + "CTRY", + } + }, + LABEL = { + types = { + "HOME", + "WORK", + "POSTAL", + "PARCEL", + "DOM", + "INTL", + "PREF", + }, + value = "LINE", + }, + TEL = { + types = { + "HOME", + "WORK", + "VOICE", + "FAX", + "PAGER", + "MSG", + "CELL", + "VIDEO", + "BBS", + "MODEM", + "ISDN", + "PCS", + "PREF", + }, + value = "NUMBER", + }, + EMAIL = { + types = { + "HOME", + "WORK", + "INTERNET", + "PREF", + "X400", + }, + value = "USERID", + }, + JABBERID = "text", + MAILER = "text", + TZ = "text", + GEO = { + values = { + "LAT", + "LON", + }, + }, + TITLE = "text", + ROLE = "text", + LOGO = "copy of PHOTO", + AGENT = "text", + ORG = { + values = { + behaviour = "repeat-last", + "ORGNAME", + "ORGUNIT", + } + }, + CATEGORIES = { + values = "KEYWORD", + }, + NOTE = "text", + PRODID = "text", + REV = "text", + SORTSTRING = "text", + SOUND = "copy of PHOTO", + UID = "text", + URL = "text", + CLASS = { + names = { -- The item.name is the value if it's one of these. + "PUBLIC", + "PRIVATE", + "CONFIDENTIAL", + }, + }, + KEY = { + props = { "TYPE" }, + value = "CRED", + }, + DESC = "text", +}; +vCard_dtd.LOGO = vCard_dtd.PHOTO; +vCard_dtd.SOUND = vCard_dtd.PHOTO; + +vCard4_dtd = { + source = "uri", + kind = "text", + xml = "text", + fn = "text", + n = { + values = { + "family", + "given", + "middle", + "prefix", + "suffix", + }, + }, + nickname = "text", + photo = "uri", + bday = "date-and-or-time", + anniversary = "date-and-or-time", + gender = "text", + adr = { + values = { + "pobox", + "ext", + "street", + "locality", + "region", + "code", + "country", + } + }, + tel = "text", + email = "text", + impp = "uri", + lang = "language-tag", + tz = "text", + geo = "uri", + title = "text", + role = "text", + logo = "uri", + org = "text", + member = "uri", + related = "uri", + categories = "text", + note = "text", + prodid = "text", + rev = "timestamp", + sound = "uri", + uid = "uri", + clientpidmap = "number, uuid", + url = "uri", + version = "text", + key = "uri", + fburl = "uri", + caladruri = "uri", + caluri = "uri", +}; + +return { + from_text = from_text; + to_text = to_text; + + from_xep54 = from_xep54; + to_xep54 = to_xep54; + + to_vcard4 = vcards_to_vcard4xml; +}; diff --git a/util/watchdog.lua b/util/watchdog.lua index aa8c6486..516e60e4 100644 --- a/util/watchdog.lua +++ b/util/watchdog.lua @@ -3,6 +3,7 @@ local setmetatable = setmetatable; local os_time = os.time; local _ENV = nil; +-- luacheck: std none local watchdog_methods = {}; local watchdog_mt = { __index = watchdog_methods }; diff --git a/util/x509.lua b/util/x509.lua index f228b201..15cc4d3c 100644 --- a/util/x509.lua +++ b/util/x509.lua @@ -25,6 +25,7 @@ local log = require "util.logger".init("x509"); local s_format = string.format; local _ENV = nil; +-- luacheck: std none local oid_commonname = "2.5.4.3"; -- [LDAP] 2.3 local oid_subjectaltname = "2.5.29.17"; -- [PKIX] 4.2.1.6 diff --git a/util/xml.lua b/util/xml.lua index 733d821a..dac3f6fe 100644 --- a/util/xml.lua +++ b/util/xml.lua @@ -1,8 +1,11 @@ local st = require "util.stanza"; local lxp = require "lxp"; +local t_insert = table.insert; +local t_remove = table.remove; local _ENV = nil; +-- luacheck: std none local parse_xml = (function() local ns_prefixes = { @@ -14,6 +17,21 @@ local parse_xml = (function() --luacheck: ignore 212/self local handler = {}; local stanza = st.stanza("root"); + local namespaces = {}; + local prefixes = {}; + function handler:StartNamespaceDecl(prefix, url) + if prefix ~= nil then + t_insert(namespaces, url); + t_insert(prefixes, prefix); + end + end + function handler:EndNamespaceDecl(prefix) + if prefix ~= nil then + -- we depend on each StartNamespaceDecl having a paired EndNamespaceDecl + t_remove(namespaces); + t_remove(prefixes); + end + end function handler:StartElement(tagname, attr) local curr_ns,name = tagname:match(ns_pattern); if name == "" then @@ -34,7 +52,11 @@ local parse_xml = (function() end end end - stanza:tag(name, attr); + local n = {} + for i=1,#namespaces do + n[prefixes[i]] = namespaces[i]; + end + stanza:tag(name, attr, n); end function handler:CharacterData(data) stanza:text(data); diff --git a/util/xmppstream.lua b/util/xmppstream.lua index 7be63285..8c7851a5 100644 --- a/util/xmppstream.lua +++ b/util/xmppstream.lua @@ -25,6 +25,7 @@ local lxp_supports_bytecount = not not lxp.new({}).getcurrentbytecount; local default_stanza_size_limit = 1024*1024*10; -- 10MB local _ENV = nil; +-- luacheck: std none local new_parser = lxp.new; @@ -47,7 +48,10 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) local cb_streamopened = stream_callbacks.streamopened; local cb_streamclosed = stream_callbacks.streamclosed; - local cb_error = stream_callbacks.error or function(session, e, stanza) error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); end; + local cb_error = stream_callbacks.error or + function(_, e, stanza) + error("XML stream error: "..tostring(e)..(stanza and ": "..tostring(stanza) or ""),2); + end; local cb_handlestanza = stream_callbacks.handlestanza; cb_handleprogress = cb_handleprogress or dummy_cb; @@ -128,6 +132,9 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) end if lxp_supports_xmldecl then function xml_handlers:XmlDecl(version, encoding, standalone) + session.xml_version = version; + session.xml_encoding = encoding; + session.xml_standalone = standalone; if lxp_supports_bytecount then cb_handleprogress(self:getcurrentbytecount()); end @@ -214,7 +221,7 @@ local function new_sax_handlers(session, stream_callbacks, cb_handleprogress) stack = {}; end - local function set_session(stream, new_session) + local function set_session(stream, new_session) -- luacheck: ignore 212/stream session = new_session; end @@ -238,7 +245,7 @@ local function new(session, stream_callbacks, stanza_size_limit) local parser = new_parser(handlers, ns_separator, false); local parse = parser.parse; - function session.open_stream(session, from, to) + function session.open_stream(session, from, to) -- luacheck: ignore 432/session local send = session.sends2s or session.send; local attr = { @@ -264,7 +271,7 @@ local function new(session, stream_callbacks, stanza_size_limit) n_outstanding_bytes = 0; meta.reset(); end, - feed = function (self, data) + feed = function (self, data) -- luacheck: ignore 212/self if lxp_supports_bytecount then n_outstanding_bytes = n_outstanding_bytes + #data; end |